Posted to commits@tvm.apache.org by tq...@apache.org on 2020/07/02 19:24:28 UTC

[incubator-tvm] branch master updated: [Target] Migrate data structure of TargetNode (#5960)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ce8a1c  [Target] Migrate data structure of TargetNode (#5960)
6ce8a1c is described below

commit 6ce8a1cb5fbddf1acf0ed9e00eef2d3e5071f86f
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Jul 2 12:23:57 2020 -0700

    [Target] Migrate data structure of TargetNode (#5960)
---
 include/tvm/target/target.h                        |  75 +++--
 include/tvm/target/target_id.h                     |  55 ++++
 python/tvm/autotvm/tophub.py                       |   9 +-
 python/tvm/driver/build_module.py                  |   4 +-
 python/tvm/relay/op/strategy/cuda.py               |  22 +-
 python/tvm/relay/op/strategy/rocm.py               |   2 +-
 python/tvm/relay/qnn/op/legalizations.py           |   7 +-
 python/tvm/relay/quantize/_calibrate.py            |   2 +-
 python/tvm/target/__init__.py                      |   2 +-
 python/tvm/target/target.py                        |  76 ++---
 src/driver/driver_api.cc                           |  11 +-
 src/relay/backend/build_module.cc                  |   4 +-
 src/target/codegen.cc                              |   7 +-
 src/target/generic_func.cc                         |   2 +-
 src/target/source/codegen_aocl.cc                  |   5 +-
 src/target/source/codegen_vhls.cc                  |   3 +-
 src/target/target.cc                               | 248 +++++-----------
 src/target/target_id.cc                            | 329 +++++++++++++++++++--
 .../schedule_postproc_rewrite_for_tensor_core.cc   |   2 +-
 src/tir/analysis/verify_memory.cc                  |   2 +-
 src/tir/transforms/lower_custom_datatypes.cc       |   2 +-
 src/tir/transforms/lower_intrin.cc                 |  13 +-
 src/tir/transforms/lower_thread_allreduce.cc       |   7 +-
 src/tir/transforms/lower_warp_memory.cc            |   3 +-
 src/tir/transforms/make_packed_api.cc              |   2 +-
 tests/micro/test_runtime_micro_on_arm.py           |   2 +-
 tests/python/unittest/test_runtime_micro.py        |   2 +-
 tests/python/unittest/test_target_target.py        |   4 +-
 topi/include/topi/cuda/dense.h                     |   4 +-
 topi/include/topi/cuda/injective.h                 |   2 +-
 topi/include/topi/cuda/pooling.h                   |   2 +-
 topi/include/topi/cuda/reduction.h                 |   4 +-
 topi/include/topi/rocm/dense.h                     |   4 +-
 topi/python/topi/arm_cpu/conv2d_gemm.py            |   2 +-
 topi/python/topi/cuda/batch_matmul.py              |   2 +-
 topi/python/topi/cuda/conv1d.py                    |   4 +-
 topi/python/topi/cuda/conv1d_transpose_ncw.py      |   2 +-
 topi/python/topi/cuda/conv2d_direct.py             |   4 +-
 topi/python/topi/cuda/conv2d_nhwc.py               |   2 +-
 topi/python/topi/cuda/conv2d_nhwc_tensorcore.py    |   2 +-
 topi/python/topi/cuda/conv2d_transpose_nchw.py     |   2 +-
 topi/python/topi/cuda/conv2d_winograd.py           |   2 +-
 topi/python/topi/cuda/conv3d_direct.py             |   4 +-
 topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py   |   2 +-
 topi/python/topi/cuda/conv3d_winograd.py           |   4 +-
 topi/python/topi/cuda/correlation.py               |   2 +-
 topi/python/topi/cuda/deformable_conv2d.py         |   2 +-
 topi/python/topi/cuda/dense_tensorcore.py          |   2 +-
 topi/python/topi/cuda/depthwise_conv2d.py          |   6 +-
 topi/python/topi/cuda/group_conv2d_nchw.py         |   2 +-
 topi/python/topi/cuda/reduction.py                 |   2 +-
 topi/python/topi/cuda/softmax.py                   |   4 +-
 topi/python/topi/cuda/vision.py                    |   2 +-
 topi/python/topi/generic/default.py                |   2 +-
 topi/python/topi/generic/injective.py              |   2 +-
 topi/python/topi/generic/vision.py                 |   2 +-
 .../python/topi/intel_graphics/depthwise_conv2d.py |   6 +-
 57 files changed, 639 insertions(+), 345 deletions(-)
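
At a high level, this patch folds the old per-field layout of TargetNode (target_name, device_name, device_type, max_num_threads, thread_warp_size, keys_array, options_array, libs_array) into an id, an optional tag, a keys array, and a generic attrs map. A minimal C++ sketch (not part of the commit) of the intended access pattern, based on the hunks below; the option string and values are illustrative:

    #include <tvm/target/target.h>

    void Example() {
      using namespace tvm;
      // Parse a target string; options now land in the attrs map.
      Target t = Target::Create("cuda -libs=cudnn,cublas -max_num_threads=2048");
      // The target name now lives on the TargetId.
      String name = t->id->name;                                          // "cuda"
      // Scalar attributes are fetched through GetAttr<T>.
      Integer threads = t->GetAttr<Integer>("max_num_threads").value();   // 2048
      // -libs is parsed into an Array<String>; GetLibs() flattens it to a set.
      std::unordered_set<std::string> libs = t->GetLibs();                // {"cudnn", "cublas"}
    }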

diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index c85349d..30ae19a 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -28,6 +28,7 @@
 #include <tvm/ir/transform.h>
 #include <tvm/node/container.h>
 #include <tvm/support/with.h>
+#include <tvm/target/target_id.h>
 
 #include <string>
 #include <unordered_set>
@@ -42,45 +43,50 @@ namespace tvm {
  */
 class TargetNode : public Object {
  public:
-  /*! \brief The name of the target device */
-  std::string target_name;
-  /*! \brief The name of the target device */
-  std::string device_name;
-  /*! \brief The type of the target device */
-  int device_type;
-  /*! \brief The maximum threads that a schedule should use for this device */
-  int max_num_threads = 1;
-  /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
-  int thread_warp_size = 1;
+  /*! \brief The id of the target device */
+  TargetId id;
+  /*! \brief Tag of the target, can be empty */
+  String tag;
   /*! \brief Keys for this target */
-  Array<runtime::String> keys_array;
-  /*! \brief Options for this target */
-  Array<runtime::String> options_array;
-  /*! \brief Collection of imported libs */
-  Array<runtime::String> libs_array;
+  Array<String> keys;
+  /*! \brief Collection of attributes */
+  Map<String, ObjectRef> attrs;
 
   /*! \return the full device string to pass to codegen::Build */
   TVM_DLL const std::string& str() const;
 
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("target_name", &target_name);
-    v->Visit("device_name", &device_name);
-    v->Visit("device_type", &device_type);
-    v->Visit("max_num_threads", &max_num_threads);
-    v->Visit("thread_warp_size", &thread_warp_size);
-    v->Visit("keys_array", &keys_array);
-    v->Visit("options_array", &options_array);
-    v->Visit("libs_array", &libs_array);
+    v->Visit("id", &id);
+    v->Visit("tag", &tag);
+    v->Visit("keys_", &keys);
+    v->Visit("attrs", &attrs);
+    v->Visit("_str_repr_", &str_repr_);
   }
 
-  /*! \brief Get the keys for this target as a vector of string */
-  TVM_DLL std::vector<std::string> keys() const;
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(
+      const std::string& attr_key,
+      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
+    static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
+                  "Can only call GetAttr with ObjectRef types.");
+    auto it = attrs.find(attr_key);
+    if (it != attrs.end()) {
+      return Downcast<Optional<TObjectRef>>((*it).second);
+    } else {
+      return default_value;
+    }
+  }
 
-  /*! \brief Get the options for this target as a vector of string */
-  TVM_DLL std::vector<std::string> options() const;
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
+  }
+
+  /*! \brief Get the keys for this target as a vector of string */
+  TVM_DLL std::vector<std::string> GetKeys() const;
 
   /*! \brief Get the keys for this target as an unordered_set of string */
-  TVM_DLL std::unordered_set<std::string> libs() const;
+  TVM_DLL std::unordered_set<std::string> GetLibs() const;
 
   static constexpr const char* _type_key = "Target";
   TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
@@ -88,6 +94,7 @@ class TargetNode : public Object {
  private:
   /*! \brief Internal string repr. */
   mutable std::string str_repr_;
+  friend class Target;
 };
 
 /*!
@@ -102,7 +109,17 @@ class Target : public ObjectRef {
    * \brief Create a Target given a string
    * \param target_str the string to parse
    */
-  TVM_DLL static Target Create(const std::string& target_str);
+  TVM_DLL static Target Create(const String& target_str);
+  /*!
+   * \brief Construct a Target node from the given name and options.
+   * \param name The major target name. Should be one of
+   * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm",
+   *  "metal", "nvptx", "opencl", "rocm", "sdaccel", "stackvm", "vulkan"}
+   * \param options Additional options appended to the target
+   * \return The constructed Target
+   */
+  TVM_DLL static Target CreateTarget(const std::string& name,
+                                     const std::vector<std::string>& options);
   /*!
    * \brief Get the current target context from thread local storage.
    * \param allow_not_defined If the context stack is empty and this is set to true, an
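The GetAttr helpers introduced above return an Optional; later hunks in this patch (codegen_aocl.cc, codegen_vhls.cc, driver_api.cc) use the two call patterns sketched here. This is a sketch only, assuming a valid Target is passed in:

    #include <tvm/target/target.h>

    void UseDeviceAttr(const tvm::Target& target) {
      using namespace tvm;
      // Without a default: check .defined() before use.
      Optional<String> device = target->GetAttr<String>("device");
      if (device.defined()) {
        std::string board = device.value();  // codegen_aocl.cc appends "-board=" + device.value()
      }
      // With a default: .value() is always safe (pattern used in codegen_vhls.cc).
      String dev = target->GetAttr<String>("device", "").value();
    }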
diff --git a/include/tvm/target/target_id.h b/include/tvm/target/target_id.h
index 93c88c75..e8d53a3 100644
--- a/include/tvm/target/target_id.h
+++ b/include/tvm/target/target_id.h
@@ -43,6 +43,8 @@ template <typename, typename, typename>
 struct ValueTypeInfoMaker;
 }
 
+class Target;
+
 /*! \brief Perform schema validation */
 TVM_DLL void TargetValidateSchema(const Map<String, ObjectRef>& config);
 
@@ -54,6 +56,10 @@ class TargetIdNode : public Object {
  public:
   /*! \brief Name of the target id */
   String name;
+  /*! \brief Device type of target id */
+  int device_type;
+  /*! \brief Default keys of the target */
+  Array<String> default_keys;
   /*! \brief Stores the required type_key and type_index of a specific attr of a target */
   struct ValueTypeInfo {
     String type_key;
@@ -62,6 +68,14 @@ class TargetIdNode : public Object {
     std::unique_ptr<ValueTypeInfo> val;
   };
 
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("device_type", &device_type);
+    v->Visit("default_keys", &default_keys);
+  }
+
+  Map<String, ObjectRef> ParseAttrsFromRawString(const std::vector<std::string>& options);
+
   static constexpr const char* _type_key = "TargetId";
   TVM_DECLARE_FINAL_OBJECT_INFO(TargetIdNode, Object);
 
@@ -72,9 +86,12 @@ class TargetIdNode : public Object {
   void ValidateSchema(const Map<String, ObjectRef>& config) const;
   /*! \brief A hash table that stores the type information of each attr of the target key */
   std::unordered_map<String, ValueTypeInfo> key2vtype_;
+  /*! \brief A hash table that stores the default value of each attr of the target key */
+  std::unordered_map<String, ObjectRef> key2default_;
   /*! \brief Index used for internal lookup of attribute registry */
   uint32_t index_;
   friend void TargetValidateSchema(const Map<String, ObjectRef>&);
+  friend class Target;
   friend class TargetId;
   template <typename, typename>
   friend class AttrRegistry;
@@ -91,6 +108,7 @@ class TargetIdNode : public Object {
  */
 class TargetId : public ObjectRef {
  public:
+  TargetId() = default;
   /*! \brief Get the attribute map given the attribute name */
   template <typename ValueType>
   static inline TargetIdAttrMap<ValueType> GetAttrMap(const String& attr_name);
@@ -110,6 +128,7 @@ class TargetId : public ObjectRef {
   template <typename, typename>
   friend class AttrRegistry;
   friend class TargetIdRegEntry;
+  friend class Target;
 };
 
 /*!
@@ -149,12 +168,30 @@ class TargetIdRegEntry {
   inline TargetIdRegEntry& set_attr(const String& attr_name, const ValueType& value,
                                     int plevel = 10);
   /*!
+   * \brief Set DLPack's device_type of the target
+   * \param device_type Device type
+   */
+  inline TargetIdRegEntry& set_device_type(int device_type);
+  /*!
+   * \brief Set the default keys of the target
+   * \param keys The default keys
+   */
+  inline TargetIdRegEntry& set_default_keys(std::vector<String> keys);
+  /*!
    * \brief Register a valid configuration option and its ValueType for validation
    * \param key The configuration key
    * \tparam ValueType The value type to be registered
    */
   template <typename ValueType>
   inline TargetIdRegEntry& add_attr_option(const String& key);
+  /*!
+   * \brief Register a valid configuration option and its ValueType for validation
+   * \param key The configuration key
+   * \param default_value The default value of the key
+   * \tparam ValueType The value type to be registered
+   */
+  template <typename ValueType>
+  inline TargetIdRegEntry& add_attr_option(const String& key, ObjectRef default_value);
   /*! \brief Set name of the TargetId to be the same as registry if it is empty */
   inline TargetIdRegEntry& set_name();
   /*!
@@ -286,6 +323,16 @@ inline TargetIdRegEntry& TargetIdRegEntry::set_attr(const String& attr_name, con
   return *this;
 }
 
+inline TargetIdRegEntry& TargetIdRegEntry::set_device_type(int device_type) {
+  id_->device_type = device_type;
+  return *this;
+}
+
+inline TargetIdRegEntry& TargetIdRegEntry::set_default_keys(std::vector<String> keys) {
+  id_->default_keys = keys;
+  return *this;
+}
+
 template <typename ValueType>
 inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) {
   CHECK(!id_->key2vtype_.count(key))
@@ -294,6 +341,14 @@ inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) {
   return *this;
 }
 
+template <typename ValueType>
+inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key,
+                                                           ObjectRef default_value) {
+  add_attr_option<ValueType>(key);
+  id_->key2default_[key] = default_value;
+  return *this;
+}
+
 inline TargetIdRegEntry& TargetIdRegEntry::set_name() {
   if (id_->name.empty()) {
     id_->name = name;
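The registration entry points declared above (set_device_type, set_default_keys, add_attr_option with and without a default) are exercised by the real registrations in src/target/target_id.cc further down in this patch. A hedged sketch with a hypothetical id, just to show the shape:

    // "example_dev" is hypothetical; see the actual TVM_REGISTER_TARGET_ID
    // blocks for "llvm", "cuda", etc. in src/target/target_id.cc below.
    TVM_REGISTER_TARGET_ID("example_dev")
        .add_attr_option<Array<String>>("libs")                      // no default value
        .add_attr_option<Integer>("max_num_threads", Integer(256))   // default filled in by the parser
        .set_default_keys({"example_dev", "gpu"})
        .set_device_type(kDLOpenCL);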
diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py
index a11c16b..c0e4eed 100644
--- a/python/tvm/autotvm/tophub.py
+++ b/python/tvm/autotvm/tophub.py
@@ -103,11 +103,10 @@ def context(target, extra_files=None):
             tgt = _target.create(tgt)
 
         possible_names = []
-        for opt in tgt.options:
-            if opt.startswith("-device"):
-                device = _alias(opt[8:])
-                possible_names.append(device)
-        possible_names.append(tgt.target_name)
+        device = tgt.attrs.get("device", "")
+        if device != "":
+            possible_names.append(_alias(device))
+        possible_names.append(tgt.id.name)
 
         all_packages = list(PACKAGE_VERSION.keys())
         for name in possible_names:
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 47e9a81..b107000 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -238,7 +238,7 @@ def _build_for_device(input_mod, target, target_host):
     """
     target = _target.create(target)
     target_host = _target.create(target_host)
-    device_type = ndarray.context(target.target_name, 0).device_type
+    device_type = ndarray.context(target.id.name, 0).device_type
 
     mod_mixed = input_mod
     mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
@@ -402,7 +402,7 @@ def build(inputs,
     if not target_host:
         for tar, _ in target_input_mod.items():
             tar = _target.create(tar)
-            device_type = ndarray.context(tar.target_name, 0).device_type
+            device_type = ndarray.context(tar.id.name, 0).device_type
             if device_type == ndarray.cpu(0).device_type:
                 target_host = tar
                 break
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index a1c88b8..d626a9d 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -68,7 +68,7 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
         wrap_compute_softmax(topi.nn.softmax),
         wrap_topi_schedule(topi.cuda.schedule_softmax),
         name="softmax.cuda")
-    if target.target_name == "cuda" and "cudnn" in target.libs:
+    if target.id.name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(
             wrap_compute_softmax(topi.cuda.softmax_cudnn),
             wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                                                                              dilation_h, dilation_w,
                                                                              pre_flag=False)
             if judge_winograd_shape:
-                if target.target_name == "cuda" and \
+                if target.id.name == "cuda" and \
                     nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
                     judge_winograd_tensorcore:
                     strategy.add_implementation(
@@ -162,7 +162,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                             topi.cuda.schedule_conv2d_nhwc_winograd_direct),
                         name="conv2d_nhwc_winograd_direct.cuda",
                         plevel=5)
-            if target.target_name == "cuda":
+            if target.id.name == "cuda":
                 if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                     if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
                             (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
@@ -181,7 +181,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         else:
             raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
         # add cudnn implementation
-        if target.target_name == "cuda" and "cudnn" in target.libs:
+        if target.id.name == "cuda" and "cudnn" in target.libs:
             if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
                     padding[1] == padding[3]:
                 strategy.add_implementation(
@@ -209,7 +209,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     else: # group_conv2d
         # add cudnn implementation, if any
         cudnn_impl = False
-        if target.target_name == "cuda" and "cudnn" in target.libs:
+        if target.id.name == "cuda" and "cudnn" in target.libs:
             if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
                     padding[1] == padding[3]:
                 strategy.add_implementation(
@@ -264,7 +264,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
                                                       padding, stride_h, stride_w,
                                                       dilation_h, dilation_w,
                                                       pre_flag=True)
-        if target.target_name == "cuda" and \
+        if target.id.name == "cuda" and \
             nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
             judge_winograd_tensorcore:
             strategy.add_implementation(
@@ -362,7 +362,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
             plevel=10)
         N, _, _, _, _ = get_const_tuple(data.shape)
         _, _, _, CI, CO = get_const_tuple(kernel.shape)
-        if target.target_name == "cuda":
+        if target.id.name == "cuda":
             if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                 if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
                 (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
@@ -373,7 +373,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
                         name="conv3d_ndhwc_tensorcore.cuda",
                         plevel=20)
 
-    if target.target_name == "cuda" and "cudnn" in target.libs:
+    if target.id.name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
                                     wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
                                     name="conv3d_cudnn.cuda",
@@ -458,7 +458,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
                 name="dense_large_batch.cuda",
                 plevel=5)
-        if target.target_name == "cuda":
+        if target.id.name == "cuda":
             if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                 if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
                         or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
@@ -468,7 +468,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
                         wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
                         name="dense_tensorcore.cuda",
                         plevel=20)
-    if target.target_name == "cuda" and "cublas" in target.libs:
+    if target.id.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_cublas),
             wrap_topi_schedule(topi.cuda.schedule_dense_cublas),
@@ -485,7 +485,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
         name="batch_matmul.cuda",
         plevel=10)
-    if target.target_name == "cuda" and "cublas" in target.libs:
+    if target.id.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
             wrap_topi_schedule(topi.generic.schedule_extern),
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index b1213f1..a80b6ca 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -127,7 +127,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
         wrap_compute_dense(topi.rocm.dense),
         wrap_topi_schedule(topi.rocm.schedule_dense),
         name="dense.rocm")
-    if target.target_name == "rocm" and "rocblas" in target.libs:
+    if target.id.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
             wrap_compute_dense(topi.rocm.dense_rocblas),
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index 7246214..00866e0 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -229,18 +229,17 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
 def is_fast_int8_on_intel():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
     target = tvm.target.Target.current(allow_none=False)
-    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
-    return intel_supported_arches.intersection(set(target.options))
+    return target.mcpu in {'skylake-avx512', 'cascadelake'}
 
 def is_fast_int8_on_arm():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
     target = tvm.target.Target.current(allow_none=False)
-    return '+v8.2a,+dotprod' in ' '.join(target.options)
+    return '+v8.2a,+dotprod' in target.mattr
 
 def is_aarch64_arm():
     """ Checks whether we are compiling for an AArch64 target. """
     target = tvm.target.Target.current(allow_none=False)
-    return 'aarch64' in ' '.join(target.options)
+    return 'aarch64' in target.attrs.get("target", "")
 
 ########################
 # ARM CPU legalizations.
diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py
index 9590e87..74a6f60 100644
--- a/python/tvm/relay/quantize/_calibrate.py
+++ b/python/tvm/relay/quantize/_calibrate.py
@@ -39,7 +39,7 @@ def _get_profile_runtime(mod):
 
     if tvm.target.Target.current():
         target = tvm.target.Target.current()
-        ctx = tvm.context(target.target_name)
+        ctx = tvm.context(target.id.name)
     else:
         target = 'llvm'
         ctx = tvm.context(target)
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index 2553fed..18a9e7e 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -16,7 +16,7 @@
 # under the License.
 """Target description and codgen module.
 
-TVM's target string is in fomat ``<target_name> [-option=value]...``.
+TVM's target string is in the format ``<target_id> [-option=value]...``.
 
 Note
 ----
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 3335e12..a2a4501 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -23,6 +23,12 @@ from . import _ffi_api
 
 
 @tvm._ffi.register_object
+class TargetId(Object):
+    """Id of a compilation target
+    """
+
+
+@tvm._ffi.register_object
 class Target(Object):
     """Target device information, use through TVM API.
 
@@ -41,45 +47,15 @@ class Target(Object):
         # Always override new to enable class
         obj = Object.__new__(cls)
         obj._keys = None
-        obj._options = None
         obj._libs = None
         return obj
 
     @property
     def keys(self):
         if not self._keys:
-            self._keys = [str(k) for k in self.keys_array]
+            self._keys = [str(k) for k in self.keys_]
         return self._keys
 
-    @property
-    def options(self):
-        if not self._options:
-            self._options = [str(o) for o in self.options_array]
-        return self._options
-
-    @property
-    def libs(self):
-        if not self._libs:
-            self._libs = [str(l) for l in self.libs_array]
-        return self._libs
-
-    @property
-    def model(self):
-        for opt in self.options_array:
-            if opt.startswith('-model='):
-                return opt[7:]
-        return 'unknown'
-
-    @property
-    def mcpu(self):
-        """Returns the mcpu from the target if it exists."""
-        mcpu = ''
-        if self.options is not None:
-            for opt in self.options:
-                if 'mcpu' in opt:
-                    mcpu = opt.split('=')[1]
-        return mcpu
-
     def __enter__(self):
         _ffi_api.EnterTargetScope(self)
         return self
@@ -102,6 +78,40 @@ class Target(Object):
         """
         return _ffi_api.GetCurrentTarget(allow_none)
 
+    @property
+    def max_num_threads(self):
+        return int(self.attrs["max_num_threads"])
+
+    @property
+    def thread_warp_size(self):
+        return int(self.attrs["thread_warp_size"])
+
+    @property
+    def device_name(self):
+        return str(self.attrs.get("device", ""))
+
+    @property
+    def model(self):
+        """Returns model from the target if it exists."""
+        return str(self.attrs.get("model", "unknown"))
+
+    @property
+    def mcpu(self):
+        """Returns the mcpu from the target if it exists."""
+        return str(self.attrs.get("mcpu", ""))
+
+    @property
+    def mattr(self):
+        """Returns the mattr from the target if it exists."""
+        return self.attrs.get("mattr", "")
+
+    @property
+    def libs(self):
+        if not self._libs:
+            self._libs = list(self.attrs.get("libs", ""))
+        return self._libs
+
+
 
 def _merge_opts(opts, new_opts):
     """Helper function to merge options"""
@@ -167,7 +177,7 @@ def intel_graphics(model='unknown', options=None):
     options : str or list of str
         Additional options
     """
-    opts = ["-device=intel_graphics", '-model=%s' % model]
+    opts = ["-device=intel_graphics", "-model=%s" % model, "-thread_warp_size=16"]
     opts = _merge_opts(opts, options)
     return _ffi_api.TargetCreate("opencl", *opts)
 
@@ -216,7 +226,7 @@ def rasp(options=None):
 
 
 def vta(model='unknown', options=None):
-    opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
+    opts = ["-device=vta", '-keys=vta,cpu', '-model=%s' % model]
     opts = _merge_opts(opts, options)
     ret = _ffi_api.TargetCreate("ext_dev", *opts)
     return ret
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e796f49..2c08ea1 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -56,7 +56,7 @@ bool LLVMEnabled() {
 
 /*! \return The default host target for a given device target */
 Target DefaultTargetHost(Target target) {
-  if (target.defined() && target->device_type == kDLCPU) {
+  if (target.defined() && target->id->device_type == kDLCPU) {
     return target;
   } else {
     if (LLVMEnabled()) {
@@ -232,14 +232,14 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
   auto mdevice = opt_device(mod_mixed);
 
   // some final misc checks.
-  auto keys = target->keys();
+  auto keys = target->GetKeys();
   bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
   if (target_is_gpu && mdevice->functions.size() == 0) {
     LOG(WARNING) << "Specified target " << target->str()
                  << " but cannot find device code. Did you forget to bind?";
   }
 
-  if (target->device_type == target::llvm()->device_type && target_host == target) {
+  if (target->id->device_type == kDLCPU && target_host == target) {
     CHECK(mdevice->functions.empty()) << "No device code should be generated when target "
                                       << "and host_target are both llvm target."
                                       << "\n";
@@ -256,7 +256,7 @@ runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_
   Target target_host_val = target_host;
   if (!target_host.defined()) {
     for (const auto& it : inputs) {
-      if (it.first->device_type == kDLCPU || it.first->device_type == kDLMicroDev) {
+      if (it.first->id->device_type == kDLCPU || it.first->id->device_type == kDLMicroDev) {
         target_host_val = it.first;
         break;
       }
@@ -295,7 +295,8 @@ runtime::Module build(const Map<String, IRModule>& inputs, const Target& target_
   Map<Target, IRModule> updated_input;
   for (const auto& it : inputs) {
     auto target = Target::Create(it.first);
-    if (target->device_name == "vta") {
+    Optional<String> device = target->GetAttr<String>("device");
+    if (device.defined() && device.value() == "vta") {
       target = Target::Create("ext_dev");
     }
     updated_input.Set(target, it.second);
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 34c3487..b589bcc 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -441,7 +441,7 @@ class RelayBuildModule : public runtime::ModuleNode {
       if (!target_host.defined())
         target_host = (pf != nullptr) ? target::llvm() : target::stackvm();
 
-      if (target_host.defined() && target_host->target_name == "llvm") {
+      if (target_host.defined() && target_host->id->name == "llvm") {
         // If we can decide the target is LLVM, we then create an empty LLVM module.
         ret_.mod = (*pf)(target_host->str(), "empty_module");
       } else {
@@ -467,7 +467,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     Target target_host = target_host_;
     if (!target_host_.defined()) {
       for (const auto& it : targets_) {
-        if (it.second->device_type == kDLCPU) {
+        if (it.second->id->device_type == kDLCPU) {
           target_host = it.second;
           break;
         }
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index d0a3156..52bd1c2 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -47,7 +47,12 @@ runtime::Module Build(IRModule mod, const Target& target) {
           .value()) {
     mod = tir::transform::SkipAssert()(mod);
   }
-  std::string build_f_name = "target.build." + target->target_name;
+  std::string build_f_name;
+  if (target->id->name == "micro_dev") {
+    build_f_name = "target.build.c";
+  } else {
+    build_f_name = "target.build." + target->id->name;
+  }
   // the build function.
   const PackedFunc* bf = runtime::Registry::Get(build_f_name);
   CHECK(bf != nullptr) << "target.build." << target << " is not enabled";
diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc
index 9ad9f56..b5842ee 100644
--- a/src/target/generic_func.cc
+++ b/src/target/generic_func.cc
@@ -102,7 +102,7 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
   PackedFunc func;
 
   if (target.defined()) {
-    for (auto& k : target->keys()) {
+    for (auto& k : target->GetKeys()) {
       auto iter = node->dispatch_dict_.find(k);
       if (iter != node->dispatch_dict_.end()) {
         func = iter->second;
diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc
index 2b77869..597fd37 100644
--- a/src/target/source/codegen_aocl.cc
+++ b/src/target/source/codegen_aocl.cc
@@ -62,8 +62,9 @@ runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation)
   // AOCL supports fp64.
   cmd += " -Dcl_khr_fp64";
   Target target = Target::Create(target_str);
-  if (target->device_name != "") {
-    cmd += " -board=" + target->device_name;
+  Optional<String> device = target->GetAttr<String>("device");
+  if (device.defined()) {
+    cmd += " -board=" + device.value();
   }
   if (emulation) {
     cmd += " -march=emulator";
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index e60e1f5..3d77dda 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -179,7 +179,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
   std::string xclbin;
   if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
     Target target = Target::Create(target_str);
-    xclbin = (*f)(kernel_info, target->device_name).operator std::string();
+    String device = target->GetAttr<String>("device", "").value();
+    xclbin = (*f)(kernel_info, device).operator std::string();
   } else {
     LOG(FATAL) << "Cannot compile Vivado HLS code.";
   }
diff --git a/src/target/target.cc b/src/target/target.cc
index 2104c2e..5c61867 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -24,6 +24,7 @@
 #include <tvm/node/repr_printer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/target/target.h>
+#include <tvm/target/target_id.h>
 #include <tvm/tir/expr.h>
 
 #include <algorithm>
@@ -35,6 +36,41 @@ using runtime::PackedFunc;
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
 
+Target Target::CreateTarget(const std::string& name, const std::vector<std::string>& options) {
+  TargetId id = TargetId::Get(name);
+  ObjectPtr<TargetNode> target = make_object<TargetNode>();
+  target->id = id;
+  // tag is always empty
+  target->tag = "";
+  // parse attrs
+  target->attrs = id->ParseAttrsFromRawString(options);
+  String device_name = target->GetAttr<String>("device", "").value();
+  // create string representation
+  {
+    std::ostringstream str_repr;
+    str_repr << name;
+    for (const auto& s : options) {
+      str_repr << ' ' << s;
+    }
+    target->str_repr_ = str_repr.str();
+  }
+  // set up keys
+  {
+    // user provided keys
+    Array<String> keys = target->GetAttr<Array<String>>("keys").value_or({});
+    // add `device_name`
+    if (!device_name.empty()) {
+      keys.push_back(device_name);
+    }
+    // add default keys
+    for (const auto& key : target->id->default_keys) {
+      keys.push_back(key);
+    }
+    target->keys = std::move(keys);
+  }
+  return Target(target);
+}
+
 TVM_REGISTER_NODE_TYPE(TargetNode);
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -43,119 +79,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << op->str();
     });
 
-/*!
- * \brief Construct a Target node from the given name and options.
- * \param target_name The major target name. Should be one of
- * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm",
- *  "metal", "nvptx", "opencl", "rocm", "sdaccel", "stackvm", "vulkan"}
- * \param options Additional options appended to the target
- * \return The constructed Target
- */
-Target CreateTarget(const std::string& target_name, const std::vector<std::string>& options) {
-  auto t = make_object<TargetNode>();
-  t->target_name = target_name;
-
-  std::string libs_flag = "-libs=";
-  std::string device_flag = "-device=";
-  std::string keys_flag = "-keys=";
-  for (auto& item : options) {
-    t->options_array.push_back(item);
-
-    if (item.find(libs_flag) == 0) {
-      std::stringstream ss(item.substr(libs_flag.length()));
-      std::string lib_item;
-      while (std::getline(ss, lib_item, ',')) {
-        t->libs_array.push_back(lib_item);
-      }
-    } else if (item.find(device_flag) == 0) {
-      t->device_name = item.substr(device_flag.length());
-      t->keys_array.push_back(t->device_name);
-    } else if (item.find(keys_flag) == 0) {
-      std::stringstream ss(item.substr(keys_flag.length()));
-      std::string key_item;
-      while (std::getline(ss, key_item, ',')) {
-        t->keys_array.push_back(key_item);
-      }
-    }
-  }
-
-  if (t->device_name.length() > 0) {
-    t->keys_array.push_back(t->device_name);
-  }
-  t->device_type = kDLCPU;
-  t->thread_warp_size = 1;
-  if (target_name == "c" && t->device_name == "micro_dev") {
-    t->device_type = kDLMicroDev;
-  } else if (target_name == "c" || target_name == "llvm") {
-    t->keys_array.push_back("cpu");
-  } else if (target_name == "cuda" || target_name == "nvptx") {
-    t->device_type = kDLGPU;
-    t->keys_array.push_back("cuda");
-    t->keys_array.push_back("gpu");
-    t->max_num_threads = 1024;
-    t->thread_warp_size = 32;
-  } else if (target_name == "rocm" || target_name == "opencl") {
-    // For now assume rocm schedule for opencl
-    if (target_name == "opencl") {
-      t->device_type = kDLOpenCL;
-    } else {  // rocm
-      t->device_type = kDLROCM;
-      t->thread_warp_size = 64;
-    }
-    t->keys_array.push_back(target_name);
-    t->keys_array.push_back("gpu");
-    t->max_num_threads = 256;
-    if (t->device_name == "intel_graphics") {
-      t->thread_warp_size = 16;
-    }
-  } else if (target_name == "metal" || target_name == "vulkan" || target_name == "webgpu") {
-    if (target_name == "metal") {
-      t->device_type = kDLMetal;
-    } else if (target_name == "vulkan") {
-      t->device_type = kDLVulkan;
-    } else {
-      t->device_type = kDLWebGPU;
-    }
-    t->keys_array.push_back(target_name);
-    t->keys_array.push_back("gpu");
-    t->max_num_threads = 256;
-  } else if (target_name == "sdaccel") {
-    t->device_type = kDLOpenCL;
-    t->keys_array.push_back("sdaccel");
-    t->keys_array.push_back("hls");
-  } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
-    t->device_type = kDLAOCL;
-    t->keys_array.push_back("aocl");
-    t->keys_array.push_back("hls");
-  } else if (target_name == "stackvm") {
-    t->device_type = kDLCPU;
-  } else if (target_name == "ext_dev") {
-    t->device_type = kDLExtDev;
-  } else if (target_name == "hybrid") {
-    t->device_type = kDLCPU;
-  } else if (target_name == "hexagon") {
-    t->keys_array.push_back("hexagon");
-    t->device_type = kDLHexagon;
-  } else if (target_name == "webgpu") {
-    t->keys_array.push_back("webgpu");
-    t->device_type = kDLWebGPU;
-  } else {
-    LOG(ERROR) << "Unknown target name " << target_name << "; falling back to stackvm";
-    return target::stackvm();
-  }
-
-  return Target(t);
-}
-
 TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
-  std::string target_name = args[0];
+  std::string name = args[0];
   std::vector<std::string> options;
   for (int i = 1; i < args.num_args; ++i) {
     std::string arg = args[i];
     options.push_back(arg);
   }
 
-  *ret = CreateTarget(target_name, options);
+  *ret = Target::CreateTarget(name, options);
 });
 
 TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -163,38 +95,28 @@ TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetV
   *ret = Target::Create(target_str);
 });
 
-std::vector<std::string> TargetNode::keys() const {
+std::vector<std::string> TargetNode::GetKeys() const {
   std::vector<std::string> result;
-  for (auto& expr : keys_array) {
+  for (auto& expr : keys) {
     result.push_back(expr);
   }
   return result;
 }
 
-std::vector<std::string> TargetNode::options() const {
-  std::vector<std::string> result;
-  for (auto& expr : options_array) {
-    result.push_back(expr);
+std::unordered_set<std::string> TargetNode::GetLibs() const {
+  Optional<Array<String>> libs = this->GetAttr<Array<String>>("libs");
+  if (!libs.defined()) {
+    return {};
   }
-  return result;
-}
-
-std::unordered_set<std::string> TargetNode::libs() const {
   std::unordered_set<std::string> result;
-  for (auto& expr : libs_array) {
-    result.insert(expr);
+  for (const auto& item : libs.value()) {
+    result.insert(item);
   }
   return result;
 }
 
 const std::string& TargetNode::str() const {
-  if (str_repr_.length() != 0) return str_repr_;
-  std::ostringstream result;
-  result << target_name;
-  for (const auto& x : options()) {
-    result << " " << x;
-  }
-  str_repr_ = result.str();
+  CHECK(!str_repr_.empty());
   return str_repr_;
 }
 
@@ -202,39 +124,14 @@ bool StartsWith(const std::string& str, const std::string& pattern) {
   return str.compare(0, pattern.length(), pattern) == 0;
 }
 
-std::string GetDeviceName(const std::string& target_str) {
-  std::istringstream ss(target_str);
-  std::string target_name;
-  ss >> target_name;
-
-  std::string item;
-  while (ss >> item) {
-    if (StartsWith(item, "-device=")) {
-      return item.substr(std::string("-device=").length());
-    }
+Target Target::Create(const String& target_str) {
+  std::vector<std::string> splits;
+  std::istringstream is(target_str);
+  for (std::string s; is >> s; splits.push_back(s)) {
   }
-
-  return "";
-}
-
-Target Target::Create(const std::string& target_str) {
-  if (target_str.length() == 0) {
-    LOG(ERROR) << "target_str must not be empty";
-  }
-
-  std::istringstream ss(target_str);
-  std::string target_name;
-
-  ss >> target_name;
-  auto device_name = GetDeviceName(target_str);
-
-  std::vector<std::string> options;
-  std::string item;
-  while (ss >> item) {
-    options.push_back(item);
-  }
-
-  return CreateTarget(target_name, options);
+  CHECK(!splits.empty()) << "ValueError: Cannot parse empty target string: \"" << target_str
+                         << "\"";
+  return CreateTarget(splits[0], {splits.begin() + 1, splits.end()});
 }
 
 /*! \brief Entry to hold the Target context stack. */
@@ -290,28 +187,45 @@ std::vector<std::string> MergeOptions(std::vector<std::string> opts,
   return opts;
 }
 
-Target llvm(const std::vector<std::string>& options) { return CreateTarget("llvm", options); }
+Target llvm(const std::vector<std::string>& options) {
+  return Target::CreateTarget("llvm", options);
+}
 
-Target cuda(const std::vector<std::string>& options) { return CreateTarget("cuda", options); }
+Target cuda(const std::vector<std::string>& options) {
+  return Target::CreateTarget("cuda", options);
+}
 
-Target rocm(const std::vector<std::string>& options) { return CreateTarget("rocm", options); }
+Target rocm(const std::vector<std::string>& options) {
+  return Target::CreateTarget("rocm", options);
+}
 
-Target opencl(const std::vector<std::string>& options) { return CreateTarget("opencl", options); }
+Target opencl(const std::vector<std::string>& options) {
+  return Target::CreateTarget("opencl", options);
+}
 
-Target metal(const std::vector<std::string>& options) { return CreateTarget("metal", options); }
+Target metal(const std::vector<std::string>& options) {
+  return Target::CreateTarget("metal", options);
+}
 
 Target mali(const std::vector<std::string>& options) {
-  return CreateTarget("opencl", MergeOptions(options, {"-device=mali"}));
+  return Target::CreateTarget("opencl", MergeOptions(options, {"-device=mali"}));
 }
 
 Target intel_graphics(const std::vector<std::string>& options) {
-  return CreateTarget("opencl", MergeOptions(options, {"-device=intel_graphics"}));
+  return Target::CreateTarget(
+      "opencl", MergeOptions(options, {"-device=intel_graphics", "-thread_warp_size=16"}));
 }
 
-Target stackvm(const std::vector<std::string>& options) { return CreateTarget("stackvm", options); }
+Target stackvm(const std::vector<std::string>& options) {
+  return Target::CreateTarget("stackvm", options);
+}
 
-Target ext_dev(const std::vector<std::string>& options) { return CreateTarget("ext_dev", options); }
+Target ext_dev(const std::vector<std::string>& options) {
+  return Target::CreateTarget("ext_dev", options);
+}
 
-Target hexagon(const std::vector<std::string>& options) { return CreateTarget("hexagon", options); }
+Target hexagon(const std::vector<std::string>& options) {
+  return Target::CreateTarget("hexagon", options);
+}
 }  // namespace target
 }  // namespace tvm
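Target::CreateTarget delegates option parsing to ParseAttrsFromRawString (defined in the next file), which accepts three option spellings per its case 1/2/3 branches. A sketch of equivalent inputs, with illustrative values:

    void ParseExamples() {
      using namespace tvm;
      // case 1: -key=value   (single token)
      Target a = Target::Create("cuda -max_num_threads=2048 -libs=cudnn");
      // case 2: -key value   (two tokens)
      Target b = Target::CreateTarget("cuda", {"-max_num_threads", "2048", "-libs", "cudnn"});
      // case 3: bare flag    (parsed as the value 1, e.g. -system-lib)
      Target c = Target::Create("llvm -system-lib");
    }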
diff --git a/src/target/target_id.cc b/src/target/target_id.cc
index dc1255c..faecf03 100644
--- a/src/target/target_id.cc
+++ b/src/target/target_id.cc
@@ -28,6 +28,14 @@
 
 namespace tvm {
 
+TVM_REGISTER_NODE_TYPE(TargetIdNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TargetIdNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const TargetIdNode*>(node.get());
+      p->stream << op->name;
+    });
+
 using TargetIdRegistry = AttrRegistry<TargetIdRegEntry, TargetId>;
 
 TargetIdRegEntry& TargetIdRegEntry::RegisterOrGet(const String& target_id_name) {
@@ -45,14 +53,14 @@ const AttrRegistryMapContainerMap<TargetId>& TargetId::GetAttrMapContainer(
 
 const TargetId& TargetId::Get(const String& target_id_name) {
   const TargetIdRegEntry* reg = TargetIdRegistry::Global()->Get(target_id_name);
-  CHECK(reg != nullptr) << "TargetId " << target_id_name << " is not registered";
+  CHECK(reg != nullptr) << "ValueError: TargetId \"" << target_id_name << "\" is not registered";
   return reg->id_;
 }
 
 void VerifyTypeInfo(const ObjectRef& obj, const TargetIdNode::ValueTypeInfo& info) {
   CHECK(obj.defined()) << "Object is None";
   if (!runtime::ObjectInternal::DerivedFrom(obj.get(), info.type_index)) {
-    LOG(FATAL) << "AttributeError: expect type " << info.type_key << " but get "
+    LOG(FATAL) << "AttributeError: expect type \"" << info.type_key << "\" but get "
                << obj->GetTypeKey();
     throw;
   }
@@ -74,16 +82,16 @@ void VerifyTypeInfo(const ObjectRef& obj, const TargetIdNode::ValueTypeInfo& inf
       try {
         VerifyTypeInfo(kv.first, *info.key);
       } catch (const tvm::Error& e) {
-        LOG(FATAL) << "The key of map failed type checking, where key = " << kv.first
-                   << ", value = " << kv.second << ", and the error is:\n"
+        LOG(FATAL) << "The key of map failed type checking, where key = \"" << kv.first
+                   << "\", value = \"" << kv.second << "\", and the error is:\n"
                    << e.what();
         throw;
       }
       try {
         VerifyTypeInfo(kv.second, *info.val);
       } catch (const tvm::Error& e) {
-        LOG(FATAL) << "The value of map failed type checking, where key = " << kv.first
-                   << ", value = " << kv.second << ", and the error is:\n"
+        LOG(FATAL) << "The value of map failed type checking, where key = \"" << kv.first
+                   << "\", value = \"" << kv.second << "\", and the error is:\n"
                    << e.what();
         throw;
       }
@@ -98,16 +106,18 @@ void TargetIdNode::ValidateSchema(const Map<String, ObjectRef>& config) const {
     const ObjectRef& obj = kv.second;
     if (name == kTargetId) {
       CHECK(obj->IsInstance<StringObj>())
-          << "AttributeError: \"id\" is not a string, but its type is " << obj->GetTypeKey();
+          << "AttributeError: \"id\" is not a string, but its type is \"" << obj->GetTypeKey()
+          << "\"";
       CHECK(Downcast<String>(obj) == this->name)
-          << "AttributeError: \"id\" = " << obj << " is inconsistent with TargetId " << this->name;
+          << "AttributeError: \"id\" = \"" << obj << "\" is inconsistent with TargetId \""
+          << this->name << "\"";
       continue;
     }
     auto it = key2vtype_.find(name);
     if (it == key2vtype_.end()) {
       std::ostringstream os;
-      os << "AttributeError: Invalid config option, cannot recognize \'" << name
-         << "\'. Candidates are:";
+      os << "AttributeError: Invalid config option, cannot recognize \"" << name
+         << "\". Candidates are:";
       for (const auto& kv : key2vtype_) {
         os << "\n  " << kv.first;
       }
@@ -118,8 +128,8 @@ void TargetIdNode::ValidateSchema(const Map<String, ObjectRef>& config) const {
     try {
       VerifyTypeInfo(obj, info);
     } catch (const tvm::Error& e) {
-      LOG(FATAL) << "AttributeError: Schema validation failed for TargetId " << this->name
-                 << ", details:\n"
+      LOG(FATAL) << "AttributeError: Schema validation failed for TargetId \"" << this->name
+                 << "\", details:\n"
                  << e.what() << "\n"
                  << "The config is:\n"
                  << config;
@@ -130,12 +140,12 @@ void TargetIdNode::ValidateSchema(const Map<String, ObjectRef>& config) const {
 
 inline String GetId(const Map<String, ObjectRef>& target, const char* name) {
   const String kTargetId = "id";
-  CHECK(target.count(kTargetId)) << "AttributeError: \"id\" does not exist in " << name << "\n"
+  CHECK(target.count(kTargetId)) << "AttributeError: \"id\" does not exist in \"" << name << "\"\n"
                                  << name << " = " << target;
   const ObjectRef& obj = target[kTargetId];
-  CHECK(obj->IsInstance<StringObj>()) << "AttributeError: \"id\" is not a string in " << name
-                                      << ", but its type is " << obj->GetTypeKey() << "\n"
-                                      << name << " = " << target;
+  CHECK(obj->IsInstance<StringObj>()) << "AttributeError: \"id\" is not a string in \"" << name
+                                      << "\", but its type is \"" << obj->GetTypeKey() << "\"\n"
+                                      << name << " = \"" << target << '"';
   return Downcast<String>(obj);
 }
 
@@ -156,9 +166,292 @@ void TargetValidateSchema(const Map<String, ObjectRef>& config) {
       TargetId::Get(target_host_id)->ValidateSchema(target_host);
     }
   } catch (const tvm::Error& e) {
-    LOG(INFO) << e.what();
-    throw e;
+    LOG(FATAL) << "AttributeError: schedule validation fails:\n"
+               << e.what() << "\nThe configuration is:\n"
+               << config;
+  }
+}
+
+static inline size_t CountNumPrefixDashes(const std::string& s) {
+  size_t i = 0;
+  for (; i < s.length() && s[i] == '-'; ++i) {
+  }
+  return i;
+}
+
+static inline int FindUniqueSubstr(const std::string& str, const std::string& substr) {
+  size_t pos = str.find_first_of(substr);
+  if (pos == std::string::npos) {
+    return -1;
+  }
+  size_t next_pos = pos + substr.size();
+  CHECK(next_pos >= str.size() || str.find_first_of(substr, next_pos) == std::string::npos)
+      << "ValueError: At most one \"" << substr << "\" is allowed in "
+      << "the the given string \"" << str << "\"";
+  return pos;
+}
+
+static inline ObjectRef ParseScalar(uint32_t type_index, const std::string& str) {
+  std::istringstream is(str);
+  if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    int v;
+    is >> v;
+    return is.fail() ? ObjectRef(nullptr) : Integer(v);
+  } else if (type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    std::string v;
+    is >> v;
+    return is.fail() ? ObjectRef(nullptr) : String(v);
+  }
+  return ObjectRef(nullptr);
+}
+
+Map<String, ObjectRef> TargetIdNode::ParseAttrsFromRawString(
+    const std::vector<std::string>& options) {
+  std::unordered_map<String, ObjectRef> attrs;
+  for (size_t iter = 0, end = options.size(); iter < end;) {
+    std::string s = options[iter++];
+    // remove the prefix dashes
+    size_t n_dashes = CountNumPrefixDashes(s);
+    CHECK(0 < n_dashes && n_dashes < s.size())
+        << "ValueError: Not an attribute key \"" << s << "\"";
+    s = s.substr(n_dashes);
+    // parse name-obj pair
+    std::string name;
+    std::string obj;
+    int pos;
+    if ((pos = FindUniqueSubstr(s, "=")) != -1) {
+      // case 1. --key=value
+      name = s.substr(0, pos);
+      obj = s.substr(pos + 1);
+      CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" << options[iter - 1] << "\"";
+      CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" << options[iter - 1] << "\"";
+    } else if (iter < end && options[iter][0] != '-') {
+      // case 2. --key value
+      name = s;
+      obj = options[iter++];
+    } else {
+      // case 3. --boolean-key
+      name = s;
+      obj = "1";
+    }
+    // check if `name` is invalid
+    auto it = key2vtype_.find(name);
+    if (it == key2vtype_.end()) {
+      std::ostringstream os;
+      os << "AttributeError: Invalid config option, cannot recognize \'" << name
+         << "\'. Candidates are:";
+      for (const auto& kv : key2vtype_) {
+        os << "\n  " << kv.first;
+      }
+      LOG(FATAL) << os.str();
+    }
+    // then `name` is valid, let's parse them
+    // only several types are supported when parsing raw string
+    const auto& info = it->second;
+    ObjectRef parsed_obj(nullptr);
+    if (info.type_index != ArrayNode::_type_index) {
+      parsed_obj = ParseScalar(info.type_index, obj);
+    } else {
+      Array<ObjectRef> array;
+      std::string item;
+      bool failed = false;
+      uint32_t type_index = info.key->type_index;
+      for (std::istringstream is(obj); std::getline(is, item, ',');) {
+        ObjectRef parsed_obj = ParseScalar(type_index, item);
+        if (parsed_obj.defined()) {
+          array.push_back(parsed_obj);
+        } else {
+          failed = true;
+          break;
+        }
+      }
+      if (!failed) {
+        parsed_obj = std::move(array);
+      }
+    }
+    if (!parsed_obj.defined()) {
+      LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\""
+                 << ", where attribute key is \"" << name << "\""
+                 << ", and attribute is \"" << obj << "\"";
+    }
+    attrs[name] = std::move(parsed_obj);
   }
+  // set default attribute values if they do not exist
+  for (const auto& kv : key2default_) {
+    if (!attrs.count(kv.first)) {
+      attrs[kv.first] = kv.second;
+    }
+  }
+  return attrs;
 }
 
+// TODO(@junrushao1994): remove some redundant attributes
+
+TVM_REGISTER_TARGET_ID("llvm")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<String>("mcpu")
+    .add_attr_option<String>("mattr")
+    .add_attr_option<String>("mtriple")
+    .add_attr_option<String>("target")  // FIXME: rename to mtriple
+    .set_default_keys({"cpu"})
+    .set_device_type(kDLCPU);
+
+TVM_REGISTER_TARGET_ID("c")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_default_keys({"cpu"})
+    .set_device_type(kDLCPU);
+
+TVM_REGISTER_TARGET_ID("micro_dev")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_default_keys({"micro_dev"})
+    .set_device_type(kDLMicroDev);
+
+TVM_REGISTER_TARGET_ID("cuda")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<Integer>("max_num_threads", Integer(1024))
+    .add_attr_option<Integer>("thread_warp_size", Integer(32))
+    .add_attr_option<String>("mcpu")
+    .set_default_keys({"cuda", "gpu"})
+    .set_device_type(kDLGPU);
+
+TVM_REGISTER_TARGET_ID("nvptx")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<Integer>("max_num_threads", Integer(1024))
+    .add_attr_option<Integer>("thread_warp_size", Integer(32))
+    .add_attr_option<String>("mcpu")
+    .set_default_keys({"cuda", "gpu"})
+    .set_device_type(kDLGPU);
+
+TVM_REGISTER_TARGET_ID("rocm")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .add_attr_option<Integer>("thread_warp_size", Integer(64))
+    .set_default_keys({"rocm", "gpu"})
+    .set_device_type(kDLROCM);
+
+TVM_REGISTER_TARGET_ID("opencl")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .add_attr_option<Integer>("thread_warp_size")
+    .set_default_keys({"opencl", "gpu"})
+    .set_device_type(kDLOpenCL);
+
+TVM_REGISTER_TARGET_ID("metal")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .set_default_keys({"metal", "gpu"})
+    .set_device_type(kDLMetal);
+
+TVM_REGISTER_TARGET_ID("vulkan")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .set_default_keys({"vulkan", "gpu"})
+    .set_device_type(kDLVulkan);
+
+TVM_REGISTER_TARGET_ID("webgpu")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .set_default_keys({"webgpu", "gpu"})
+    .set_device_type(kDLWebGPU);
+
+TVM_REGISTER_TARGET_ID("sdaccel")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_default_keys({"sdaccel", "hls"})
+    .set_device_type(kDLOpenCL);
+
+TVM_REGISTER_TARGET_ID("aocl")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_default_keys({"aocl", "hls"})
+    .set_device_type(kDLAOCL);
+
+TVM_REGISTER_TARGET_ID("aocl_sw_emu")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_default_keys({"aocl", "hls"})
+    .set_device_type(kDLAOCL);
+
+TVM_REGISTER_TARGET_ID("hexagon")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_default_keys({"hexagon"})
+    .set_device_type(kDLHexagon);
+
+TVM_REGISTER_TARGET_ID("stackvm")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_device_type(kDLCPU);
+
+TVM_REGISTER_TARGET_ID("ext_dev")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_device_type(kDLExtDev);
+
+TVM_REGISTER_TARGET_ID("hybrid")
+    .add_attr_option<Array<String>>("keys")
+    .add_attr_option<Array<String>>("libs")
+    .add_attr_option<String>("device")
+    .add_attr_option<String>("model")
+    .add_attr_option<Bool>("system-lib")
+    .set_device_type(kDLCPU);
+
 }  // namespace tvm
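
Note on the attribute parsing and target id registrations above: their effect is easiest to see through the Python API exercised by the updated unit test later in this diff. The snippet below is a minimal sketch, assuming a build that includes this patch; an option that is not registered for the target id would instead hit the "Invalid config option" error path shown above.

    import tvm

    # Options recognized by the "cuda" registration above become typed attributes.
    target = tvm.target.create("cuda -model=unknown -libs=cublas,cudnn")
    assert target.id.name == "cuda"             # replaces the old target_name field
    assert target.model == "unknown"            # String attribute
    assert target.keys == ["cuda", "gpu"]       # filled in from set_default_keys
    assert target.libs == ["cublas", "cudnn"]   # comma-separated Array<String>
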
diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index 75605ad..7541662 100644
--- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -1083,7 +1083,7 @@ Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
                                           Map<Tensor, Buffer> extern_buffer) {
   // Check if current lower target is CUDA
   auto target = tvm::Target::Current(true);
-  if (target.defined() && target->target_name != "cuda") {
+  if (target.defined() && target->id->name != "cuda") {
     return stmt;
   }
 
diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc
index 12ec270..f8a5986 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -177,7 +177,7 @@ bool VerifyMemory(const PrimFunc& func) {
 
   if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
       CallingConv::kDefault) {
-    MemoryAccessVerifier v(func, target.value()->device_type);
+    MemoryAccessVerifier v(func, target.value()->id->device_type);
     v.Run();
     return !v.Failed();
   } else {
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index 154023c..f5491da 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -139,7 +139,7 @@ Pass LowerCustomDatatypes() {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute";
 
-    n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body));
+    n->body = CustomDatatypesLowerer(target.value()->id->name)(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc
index 5ec4fe3..5372ef8 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -40,12 +40,11 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitExpr_;
   using IRMutatorWithAnalyzer::VisitStmt_;
 
-  IntrinInjecter(arith::Analyzer* analyzer, std::string target_name)
-      : IRMutatorWithAnalyzer(analyzer) {
-    patterns_.push_back("tvm.intrin.rule." + target_name + ".");
+  IntrinInjecter(arith::Analyzer* analyzer, std::string target) : IRMutatorWithAnalyzer(analyzer) {
+    patterns_.push_back("tvm.intrin.rule." + target + ".");
     patterns_.push_back("tvm.intrin.rule.default.");
     fma_ = runtime::Registry::Get(patterns_[0] + "fma");
-    if (target_name == "stackvm") {
+    if (target == "stackvm") {
       support_bitwise_op_ = false;
     }
   }
@@ -275,9 +274,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   bool support_bitwise_op_{true};
 };
 
-Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) {
+Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
   arith::Analyzer analyzer;
-  return IntrinInjecter(&analyzer, target_name)(std::move(stmt));
+  return IntrinInjecter(&analyzer, target)(std::move(stmt));
 }
 
 namespace transform {
@@ -288,7 +287,7 @@ Pass LowerIntrin() {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     CHECK(target.defined()) << "LowerIntrin: Require the target attribute";
     arith::Analyzer analyzer;
-    n->body = IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body));
+    n->body = IntrinInjecter(&analyzer, target.value()->id->name)(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 04b8953..17b4265 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -40,7 +40,7 @@ namespace tir {
 class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
   explicit ThreadAllreduceBuilder(const TargetNode* target)
-      : target_(target), warp_size_(target->thread_warp_size) {}
+      : target_(target), warp_size_(target->GetAttr<Integer>("thread_warp_size", 1).value()) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
@@ -484,11 +484,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
   bool is_warp_reduction(const std::vector<DataType>& types) const {
     // Only cuda target supports warp reductions.
-    if ((target_->target_name != "cuda") && (target_->target_name != "rocm")) return false;
+    if ((target_->id->name != "cuda") && (target_->id->name != "rocm")) return false;
 
     // rocm only supports 32 bit operands for shuffling at the moment
-    if ((target_->target_name == "rocm") &&
-        (std::any_of(types.begin(), types.end(), [](DataType ty) {
+    if ((target_->id->name == "rocm") && (std::any_of(types.begin(), types.end(), [](DataType ty) {
           if (ty.is_vector()) return true;
           return ty.bits() != 32;
         }))) {
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 480c62c..8892c32 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -392,7 +392,8 @@ Pass LowerWarpMemory() {
     auto* n = f.CopyOnWrite();
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
-    n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body));
+    int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value();
+    n->body = WarpMemoryRewriter(warp_size).Rewrite(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index bfcf0b7..191bb0a 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -51,7 +51,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
 
   auto target = func->GetAttr<Target>(tvm::attr::kTarget);
   CHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
-  int target_device_type = target.value()->device_type;
+  int target_device_type = target.value()->id->device_type;
 
   std::string name_hint = global_symbol.value();
 
diff --git a/tests/micro/test_runtime_micro_on_arm.py b/tests/micro/test_runtime_micro_on_arm.py
index 301677e..ed7d62f 100644
--- a/tests/micro/test_runtime_micro_on_arm.py
+++ b/tests/micro/test_runtime_micro_on_arm.py
@@ -33,7 +33,7 @@ from tvm.relay.testing import resnet
 # Ex : export CMSIS_ST_PATH="/home/yourid/st/STM32Cube_FW_F7_V1.16.0/Drivers/CMSIS"
 DEV_CONFIG_A = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666)
 DEV_CONFIG_B = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666)
-TARGET = 'c -device=micro_dev'
+TARGET = 'micro_dev'
 
 def relay_micro_build(func, dev_config, params=None):
     """Create a graph runtime module with a micro device context from a Relay function.
diff --git a/tests/python/unittest/test_runtime_micro.py b/tests/python/unittest/test_runtime_micro.py
index 2eea3df..eb137a9 100644
--- a/tests/python/unittest/test_runtime_micro.py
+++ b/tests/python/unittest/test_runtime_micro.py
@@ -28,7 +28,7 @@ from tvm.relay.testing import resnet
 # # Use the host emulated micro device.
 DEV_CONFIG_A = micro.device.host.generate_config()
 DEV_CONFIG_B = micro.device.host.generate_config()
-TARGET = 'c -device=micro_dev'
+TARGET = 'micro_dev'
 
 def relay_micro_build(func, dev_config, params=None):
     """Create a graph runtime module with a micro device context from a Relay function.
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index da7bcee..fe3799b 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -57,8 +57,8 @@ def test_target_dispatch():
 def test_target_string_parse():
     target = tvm.target.create("cuda -model=unknown -libs=cublas,cudnn")
 
-    assert target.target_name == "cuda"
-    assert target.options == ['-model=unknown', '-libs=cublas,cudnn']
+    assert target.id.name == "cuda"
+    assert target.model == "unknown"
     assert target.keys == ['cuda', 'gpu']
     assert target.libs == ['cublas', 'cudnn']
     assert str(target) == str(tvm.target.cuda(options="-libs=cublas,cudnn"))
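
With the flat options list gone, Python call sites read structured attributes instead, as in the conv2d_gemm.py change later in this diff that switches to target.attrs. A small sketch of the same access pattern, assuming this patch; the -mcpu value is purely illustrative:

    import tvm

    target = tvm.target.create("llvm -mcpu=skylake -libs=cblas")
    # typed attribute map instead of the removed `options` list
    assert str(target.attrs["mcpu"]) == "skylake"
    assert target.libs == ["cblas"]
    # absent attributes can be defaulted, mirroring attrs.get("target", "") used below
    assert target.attrs.get("device", "cpu") == "cpu"
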
diff --git a/topi/include/topi/cuda/dense.h b/topi/include/topi/cuda/dense.h
index 145d249..c8ceebf 100644
--- a/topi/include/topi/cuda/dense.h
+++ b/topi/include/topi/cuda/dense.h
@@ -62,7 +62,7 @@ inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& d
   auto in_dim = data->shape[1];
   auto out_dim = weight->shape[0];
 
-  if (target->libs().count("cublas")) {
+  if (target->GetLibs().count("cublas")) {
     CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
     auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
     if (bias.defined()) {
@@ -85,7 +85,7 @@ inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& d
  * \return A schedule for the given ops.
  */
 inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
-  if (target->target_name == "cuda" && target->libs().count("cublas")) {
+  if (target->id->name == "cuda" && target->GetLibs().count("cublas")) {
     return topi::generic::schedule_extern(target, outs);
   }
 
diff --git a/topi/include/topi/cuda/injective.h b/topi/include/topi/cuda/injective.h
index 5a5c5af..e7bce05 100644
--- a/topi/include/topi/cuda/injective.h
+++ b/topi/include/topi/cuda/injective.h
@@ -47,7 +47,7 @@ namespace cuda {
 inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
   auto fused = detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
   auto target = Target::Current(false);
-  auto num_thread = target->max_num_threads;
+  int num_thread = target->GetAttr<Integer>("max_num_threads").value();
   IterVar bx, tx;
   sch[out].split(fused, num_thread, &bx, &tx);
   sch[out].bind(bx, thread_axis(Range(), "blockIdx.x"));
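
The C++ schedules now fetch max_num_threads through GetAttr<Integer> rather than a dedicated TargetNode field; the value comes from the per-id defaults registered earlier in this diff. The Python-side convenience property used elsewhere in topi keeps working, e.g. (a minimal sketch assuming this patch):

    import tvm

    # default registered for the "cuda" id: max_num_threads = 1024
    target = tvm.target.create("cuda")
    assert int(target.max_num_threads) == 1024

    # an explicit option still overrides the registered default
    target = tvm.target.create("cuda -max_num_threads=256")
    assert int(target.max_num_threads) == 256
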
diff --git a/topi/include/topi/cuda/pooling.h b/topi/include/topi/cuda/pooling.h
index 87866f2..7e8f55d 100644
--- a/topi/include/topi/cuda/pooling.h
+++ b/topi/include/topi/cuda/pooling.h
@@ -56,7 +56,7 @@ inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) {
     if (padded_input->op->IsInstance<ComputeOpNode>()) {
       s[padded_input].compute_inline();
     }
-    auto num_thread = target->max_num_threads;
+    int num_thread = target->GetAttr<Integer>("max_num_threads").value();
     Tensor out;
     Tensor OL;
     if (detail::contains(s->outputs, pool->op)) {
diff --git a/topi/include/topi/cuda/reduction.h b/topi/include/topi/cuda/reduction.h
index 35ce346..377b922 100644
--- a/topi/include/topi/cuda/reduction.h
+++ b/topi/include/topi/cuda/reduction.h
@@ -69,7 +69,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
   if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
     all_reduce = false;
     num_thread = 32;
-    if (target->target_name == "opencl") {
+    if (target->id->name == "opencl") {
       // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
       // Don't know why.
       num_thread = 16;
@@ -79,7 +79,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
     thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
   } else {
     all_reduce = true;
-    num_thread = target->max_num_threads;
+    num_thread = target->GetAttr<Integer>("max_num_threads").value();
     thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
   }
 
diff --git a/topi/include/topi/rocm/dense.h b/topi/include/topi/rocm/dense.h
index 72f8ee6..e2e04b4 100644
--- a/topi/include/topi/rocm/dense.h
+++ b/topi/include/topi/rocm/dense.h
@@ -63,7 +63,7 @@ inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& d
   auto in_dim = data->shape[1];
   auto out_dim = weight->shape[0];
 
-  if (target->libs().count("rocblas")) {
+  if (target->GetLibs().count("rocblas")) {
     CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
     auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
     if (bias.defined()) {
@@ -86,7 +86,7 @@ inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& d
  * \return A schedule for the given ops.
  */
 inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
-  if (target->target_name == "rocm" && target->libs().count("rocblas")) {
+  if (target->id->name == "rocm" && target->GetLibs().count("rocblas")) {
     return topi::generic::schedule_extern(target, outs);
   }
 
diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py
index c1587ba..68161c3 100644
--- a/topi/python/topi/arm_cpu/conv2d_gemm.py
+++ b/topi/python/topi/arm_cpu/conv2d_gemm.py
@@ -27,7 +27,7 @@ from .tensor_intrin import gemv_quantized, gemv_quantized_impl
 def is_aarch64_arm():
     """ Checks whether we are compiling for an AArch64 target. """
     target = tvm.target.Target.current(allow_none=False)
-    return 'aarch64' in ' '.join(target.options)
+    return 'aarch64' in target.attrs.get("target", "")
 
 
 # Compute function
diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py
index 7d92edf..bcd98cc 100644
--- a/topi/python/topi/cuda/batch_matmul.py
+++ b/topi/python/topi/cuda/batch_matmul.py
@@ -69,7 +69,7 @@ def schedule_batch_matmul(cfg, outs):
         cfg.define_split("tile_k", k, num_outputs=2)
         cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
         target = tvm.target.Target.current()
-        if target.target_name in ['nvptx', 'rocm']:
+        if target.id.name in ['nvptx', 'rocm']:
             # llvm-based backends cannot do non-explicit unrolling
             cfg.define_knob("unroll_explicit", [1])
         else:
diff --git a/topi/python/topi/cuda/conv1d.py b/topi/python/topi/cuda/conv1d.py
index 3ddecbe..533cf74 100644
--- a/topi/python/topi/cuda/conv1d.py
+++ b/topi/python/topi/cuda/conv1d.py
@@ -72,7 +72,7 @@ def schedule_conv1d_ncw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.target_name in ['nvptx', 'rocm']:
+            if target.id.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -197,7 +197,7 @@ def schedule_conv1d_nwc(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.target_name in ['nvptx', 'rocm']:
+            if target.id.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/conv1d_transpose_ncw.py b/topi/python/topi/cuda/conv1d_transpose_ncw.py
index a2ac7e1..ffce584 100644
--- a/topi/python/topi/cuda/conv1d_transpose_ncw.py
+++ b/topi/python/topi/cuda/conv1d_transpose_ncw.py
@@ -124,7 +124,7 @@ def schedule_conv1d_transpose_ncw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.target_name in ['nvptx', 'rocm']:
+            if target.id.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/conv2d_direct.py b/topi/python/topi/cuda/conv2d_direct.py
index db6bff2..9d8146e 100644
--- a/topi/python/topi/cuda/conv2d_direct.py
+++ b/topi/python/topi/cuda/conv2d_direct.py
@@ -36,7 +36,7 @@ def schedule_direct_cuda(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -44,7 +44,7 @@ def schedule_direct_cuda(cfg, s, conv):
     # fallback support
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'conv2d_nchw.cuda')
+            target.id.name, target.model, 'conv2d_nchw.cuda')
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
diff --git a/topi/python/topi/cuda/conv2d_nhwc.py b/topi/python/topi/cuda/conv2d_nhwc.py
index 55714b2..c7c3f18 100644
--- a/topi/python/topi/cuda/conv2d_nhwc.py
+++ b/topi/python/topi/cuda/conv2d_nhwc.py
@@ -56,7 +56,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'conv2d_nhwc.cuda')
+            target.id.name, target.model, 'conv2d_nhwc.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     tile_n = cfg["tile_n"].val
diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
index 790db0f..7703e40 100644
--- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
+++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
@@ -134,7 +134,7 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'conv2d_nhwc_tensorcore.cuda')
+            target.id.name, target.model, 'conv2d_nhwc_tensorcore.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     block_row_warps = cfg["block_row_warps"].val
diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py
index 5ad4947..4dfcc03 100644
--- a/topi/python/topi/cuda/conv2d_transpose_nchw.py
+++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -177,7 +177,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.target_name in ['nvptx', 'rocm']:
+            if target.id.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py
index 881f63a..d976aaa 100644
--- a/topi/python/topi/cuda/conv2d_winograd.py
+++ b/topi/python/topi/cuda/conv2d_winograd.py
@@ -193,7 +193,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/conv3d_direct.py b/topi/python/topi/cuda/conv3d_direct.py
index 50b73d6..0b80e79 100644
--- a/topi/python/topi/cuda/conv3d_direct.py
+++ b/topi/python/topi/cuda/conv3d_direct.py
@@ -43,7 +43,7 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -51,7 +51,7 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     # fallback support
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, workload_name)
+            target.id.name, target.model, workload_name)
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
diff --git a/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py b/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py
index e3c7513..68b0145 100644
--- a/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py
+++ b/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py
@@ -141,7 +141,7 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'conv3d_ndhwc_tensorcore.cuda')
+            target.id.name, target.model, 'conv3d_ndhwc_tensorcore.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     block_row_warps = cfg["block_row_warps"].val
diff --git a/topi/python/topi/cuda/conv3d_winograd.py b/topi/python/topi/cuda/conv3d_winograd.py
index 5876243..e8b5037 100644
--- a/topi/python/topi/cuda/conv3d_winograd.py
+++ b/topi/python/topi/cuda/conv3d_winograd.py
@@ -321,7 +321,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -478,7 +478,7 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     cfg.define_split("tile_rz", rz, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/correlation.py b/topi/python/topi/cuda/correlation.py
index a383e4e..6d9be95 100644
--- a/topi/python/topi/cuda/correlation.py
+++ b/topi/python/topi/cuda/correlation.py
@@ -81,7 +81,7 @@ def _schedule_correlation_nchw(cfg, s, correlation):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/deformable_conv2d.py b/topi/python/topi/cuda/deformable_conv2d.py
index 8c31835..6def731 100644
--- a/topi/python/topi/cuda/deformable_conv2d.py
+++ b/topi/python/topi/cuda/deformable_conv2d.py
@@ -71,7 +71,7 @@ def _schedule_direct_cuda(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/dense_tensorcore.py b/topi/python/topi/cuda/dense_tensorcore.py
index 3546847..a6d1c05 100644
--- a/topi/python/topi/cuda/dense_tensorcore.py
+++ b/topi/python/topi/cuda/dense_tensorcore.py
@@ -95,7 +95,7 @@ def _schedule_dense_tensorcore(cfg, s, C):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'dense_tensorcore.cuda')
+            target.id.name, target.model, 'dense_tensorcore.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     # Deal with op fusion, such as bias and relu
diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py
index b7cb32d..f9ef8b6 100644
--- a/topi/python/topi/cuda/depthwise_conv2d.py
+++ b/topi/python/topi/cuda/depthwise_conv2d.py
@@ -61,7 +61,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
             target = tvm.target.Target.current()
-            if target.target_name in ['nvptx', 'rocm']:
+            if target.id.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -69,7 +69,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.target_name, target.model, 'depthwise_conv2d_nchw.cuda')
+                    target.id.name, target.model, 'depthwise_conv2d_nchw.cuda')
                 cfg.fallback_with_reference_log(ref_log)
                 # TODO(lmzheng): A bug here, set unroll_explicit to False as workaround
                 cfg['unroll_explicit'].val = 0
@@ -169,7 +169,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
         # num_thread here could be 728, it is larger than cuda.max_num_threads
         num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value
         target = tvm.target.Target.current()
-        if target and (target.target_name not in ["cuda", "nvptx"]):
+        if target and (target.id.name not in ["cuda", "nvptx"]):
             num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
         s[Output].reorder(xoc, b, h, w, xic)
diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py
index c5cf72b..e5cbe3e 100644
--- a/topi/python/topi/cuda/group_conv2d_nchw.py
+++ b/topi/python/topi/cuda/group_conv2d_nchw.py
@@ -83,7 +83,7 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
+    if target.id.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
diff --git a/topi/python/topi/cuda/reduction.py b/topi/python/topi/cuda/reduction.py
index d885c09..9d3c529 100644
--- a/topi/python/topi/cuda/reduction.py
+++ b/topi/python/topi/cuda/reduction.py
@@ -36,7 +36,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         all_reduce = False
         num_thread = 32
         target = tvm.target.Target.current()
-        if target and target.target_name == "opencl":
+        if target and target.id.name == "opencl":
             # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
             # don't know why
             num_thread = 16
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index 5f7402b..910d0f3 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -59,9 +59,9 @@ def schedule_softmax(outs):
     #
     # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
     def sched_warp_softmax():
-        if tgt.target_name == "nvptx" or tgt.target_name == "rocm":
+        if tgt.id.name == "nvptx" or tgt.id.name == "rocm":
             return softmax.dtype == "float32" or softmax.dtype == "int32"
-        if tgt.target_name != "cuda":
+        if tgt.id.name != "cuda":
             # this is used as the gpu schedule for other arches which may not have warp reductions
             return False
         return True
diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py
index eb49328..c5e2a6e 100644
--- a/topi/python/topi/cuda/vision.py
+++ b/topi/python/topi/cuda/vision.py
@@ -53,7 +53,7 @@ def schedule_reorg(outs):
         The computation schedule for reorg.
     """
     target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.target_name)
+    cpp_target = cpp.TEST_create_target(target.id.name)
     return cpp.cuda.schedule_injective(cpp_target, outs)
 
 def schedule_nms(outs):
diff --git a/topi/python/topi/generic/default.py b/topi/python/topi/generic/default.py
index 59e5a25..93a1dd2 100644
--- a/topi/python/topi/generic/default.py
+++ b/topi/python/topi/generic/default.py
@@ -24,7 +24,7 @@ def default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
     target = tvm.target.Target.current(allow_none=False)
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    if target.target_name not in ("llvm", "c"):
+    if target.id.name not in ("llvm", "c"):
         raise RuntimeError("schedule not registered for '%s'" % target)
     s = te.create_schedule([x.op for x in outs])
     if auto_inline:
diff --git a/topi/python/topi/generic/injective.py b/topi/python/topi/generic/injective.py
index fa6aee4..a60b1e7 100644
--- a/topi/python/topi/generic/injective.py
+++ b/topi/python/topi/generic/injective.py
@@ -54,7 +54,7 @@ def schedule_injective(outs):
         The computation schedule for the op.
     """
     target = tvm.target.Target.current(allow_none=False)
-    if target.target_name != "llvm":
+    if target.id.name != "llvm":
         raise RuntimeError("schedule_injective not registered for '%s'" % target)
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     x = outs[0]
diff --git a/topi/python/topi/generic/vision.py b/topi/python/topi/generic/vision.py
index edf1a48..a1db9ab 100644
--- a/topi/python/topi/generic/vision.py
+++ b/topi/python/topi/generic/vision.py
@@ -37,7 +37,7 @@ def schedule_reorg(outs):
       The computation schedule for the op.
     """
     target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.target_name)
+    cpp_target = cpp.TEST_create_target(target.id.name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
 def schedule_get_valid_counts(outs):
diff --git a/topi/python/topi/intel_graphics/depthwise_conv2d.py b/topi/python/topi/intel_graphics/depthwise_conv2d.py
index 6508099..bc2b27b 100644
--- a/topi/python/topi/intel_graphics/depthwise_conv2d.py
+++ b/topi/python/topi/intel_graphics/depthwise_conv2d.py
@@ -62,7 +62,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
             target = tvm.target.Target.current()
-            if target.target_name in ['nvptx', 'rocm']:
+            if target.id.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -70,7 +70,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.target_name, target.model, 'depthwise_conv2d_nchw.intel_graphics')
+                    target.id.name, target.model, 'depthwise_conv2d_nchw.intel_graphics')
                 cfg.fallback_with_reference_log(ref_log)
                 cfg['unroll_explicit'].val = 0
             ##### space definition end #####
@@ -170,7 +170,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
         # num_thread here could be 728, it is larger than cuda.max_num_threads
         num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value
         target = tvm.target.Target.current()
-        if target and (target.target_name not in ["cuda", "nvptx"]):
+        if target and (target.id.name not in ["cuda", "nvptx"]):
             num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
         s[Output].reorder(xoc, b, h, w, xic)