You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by kp...@apache.org on 2022/07/18 19:37:35 UTC

[tvm] branch main updated: [Target] Add target_parser to TargetKind (#12119)

This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d4309cf810 [Target] Add target_parser to TargetKind (#12119)
d4309cf810 is described below

commit d4309cf81003c5fdeeb75583ac96cc0926a22b25
Author: Christopher Sidebottom <ch...@arm.com>
AuthorDate: Mon Jul 18 20:37:20 2022 +0100

    [Target] Add target_parser to TargetKind (#12119)
    
    This adds the `target_parser` described in https://github.com/apache/tvm-rfcs/pull/71, which parses an incoming `TargetJSON` and produces a new configuration from which the final `Target` object is generated.
    
    Marks `set_attrs_preprocessor` as deprecated and raises an error if a target kind sets both `set_attrs_preprocessor` and `set_target_parser`.
---
 docs/arch/device_target_interactions.rst |  4 +--
 include/tvm/target/target_kind.h         | 23 ++++++++++++++
 src/target/target.cc                     | 17 +++++++++--
 src/target/target_kind.cc                | 52 ++++++++++++++++----------------
 tests/cpp/target_test.cc                 | 47 +++++++++++++++++++++++++++++
 5 files changed, 112 insertions(+), 31 deletions(-)

diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst
index 9c391d31be..ec8d52226e 100644
--- a/docs/arch/device_target_interactions.rst
+++ b/docs/arch/device_target_interactions.rst
@@ -194,8 +194,8 @@ different code generation targets can run on the same physical device.
 device type.)
 
 All options for a specific target kind are added with the
-``add_attr_option`` function, with optional default values.  A
-preprocessor can be added with ``set_attrs_preprocessor`` to define
+``add_attr_option`` function, with optional default values.  A `Target`
+parser can be added with ``set_target_parser`` to process
 any parameters that are dynamically based on other parameters or
 queried from device properties.
 
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 4879470e76..e20f8547af 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -37,6 +37,16 @@ namespace tvm {
 
 class Target;
 
+/*!
+ * \brief TargetParser to apply on instantiation of a given TargetKind
+ *
+ * \param target_json Target in JSON format to be transformed during parsing.
+ *
+ * \return The transformed Target JSON object.
+ */
+using TargetJSON = Map<String, ObjectRef>;
+using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>;
+
 /*!
  * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
  *
@@ -82,6 +92,8 @@ class TargetKindNode : public Object {
   Array<String> default_keys;
   /*! \brief Function used to preprocess on target creation */
   PackedFunc preprocessor;
+  /*! \brief Function used to parse a JSON target during creation */
+  FTVMTargetParser target_parser;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
@@ -207,6 +219,11 @@ class TargetKindRegEntry {
    */
   template <typename FLambda>
   inline TargetKindRegEntry& set_attrs_preprocessor(FLambda f);
+  /*!
+   * \brief Set the parsing function applied upon target creation
+   * \param parser The Target parsing function
+   */
+  inline TargetKindRegEntry& set_target_parser(FTVMTargetParser parser);
   /*!
    * \brief Register a valid configuration option and its ValueType for validation
    * \param key The configuration key
@@ -353,11 +370,17 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector<Stri
 
 template <typename FLambda>
 inline TargetKindRegEntry& TargetKindRegEntry::set_attrs_preprocessor(FLambda f) {
+  LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead";
   using FType = typename tvm::runtime::detail::function_signature<FLambda>::FType;
   kind_->preprocessor = tvm::runtime::TypedPackedFunc<FType>(std::move(f)).packed();
   return *this;
 }
 
+inline TargetKindRegEntry& TargetKindRegEntry::set_target_parser(FTVMTargetParser parser) {
+  kind_->target_parser = parser;
+  return *this;
+}
+
 template <typename ValueType>
 inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) {
   ICHECK(!kind_->key2vtype_.count(key))
diff --git a/src/target/target.cc b/src/target/target.cc
index 01f9bfaeec..207a399a77 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -52,7 +52,7 @@ class TargetInternal {
   static ObjectPtr<Object> FromString(const String& tag_or_config_or_target_str);
   static ObjectPtr<Object> FromConfigString(const String& config_str);
   static ObjectPtr<Object> FromRawString(const String& target_str);
-  static ObjectPtr<Object> FromConfig(std::unordered_map<String, ObjectRef> config);
+  static ObjectPtr<Object> FromConfig(Map<String, ObjectRef> config);
   static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv);
   static Target WithHost(const Target& target, const Target& target_host) {
     ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get());
@@ -716,17 +716,27 @@ ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
   return TargetInternal::FromConfig(config);
 }
 
-ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRef> config) {
+ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) {
   const String kKind = "kind";
   const String kTag = "tag";
   const String kKeys = "keys";
   const String kDeviceName = "device";
   const String kHost = "host";
   ObjectPtr<TargetNode> target = make_object<TargetNode>();
+
   // parse 'kind'
   if (config.count(kKind)) {
     if (const auto* kind = config[kKind].as<StringObj>()) {
       target->kind = GetTargetKind(GetRef<String>(kind));
+      ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr))
+          << "Cannot use both set_attrs_preprocessor and set_target_parser";
+
+      // Run JSON Parser over JSON input
+      if (target->kind->target_parser != nullptr) {
+        VLOG(9) << "TargetInternal::FromConfig - Running target_parser";
+        config = target->kind->target_parser(config);
+      }
+
       config.erase(kKind);
     } else {
       throw Error(": Expect type of field \"kind\" is String, but get type: " +
@@ -828,8 +838,9 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
   } else {
     target->attrs = attrs;
   }
+
   return target;
-}
+}  // namespace tvm
 
 std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
                                                                   const TargetNode* target) {
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 1148013706..7620c6fc2e 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -145,15 +145,15 @@ void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const Str
 
 /*!
  * \brief Update the attributes in the CUDA target.
- * \param attrs The original attributes
+ * \param target The Target to update
  * \return The updated attributes
  */
-Map<String, ObjectRef> UpdateCUDAAttrs(Map<String, ObjectRef> attrs) {
+TargetJSON UpdateCUDAAttrs(TargetJSON target) {
   // Update -arch=sm_xx
   int archInt;
-  if (attrs.count("arch")) {
+  if (target.count("arch")) {
     // If -arch has been specified, validate the correctness
-    String archStr = Downcast<String>(attrs.at("arch"));
+    String archStr = Downcast<String>(target.at("arch"));
     archInt = ExtractIntWithPrefix(archStr, "sm_");
     ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
   } else {
@@ -165,23 +165,23 @@ Map<String, ObjectRef> UpdateCUDAAttrs(Map<String, ObjectRef> attrs) {
     } else {
       archInt = std::stod(version.operator std::string()) * 10 + 0.1;
     }
-    attrs.Set("arch", String("sm_") + std::to_string(archInt));
+    target.Set("arch", String("sm_") + std::to_string(archInt));
   }
-  return attrs;
+  return target;
 }
 
 /*!
  * \brief Update the attributes in the LLVM NVPTX target.
- * \param attrs The original attributes
+ * \param target The Target to update
  * \return The updated attributes
  */
-Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
-  CheckOrSetAttr(&attrs, "mtriple", "nvptx64-nvidia-cuda");
+TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
+  CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
   // Update -mcpu=sm_xx
   int arch;
-  if (attrs.count("mcpu")) {
+  if (target.count("mcpu")) {
     // If -mcpu has been specified, validate the correctness
-    String mcpu = Downcast<String>(attrs.at("mcpu"));
+    String mcpu = Downcast<String>(target.at("mcpu"));
     arch = ExtractIntWithPrefix(mcpu, "sm_");
     ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
   } else {
@@ -193,22 +193,22 @@ Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
     } else {
       arch = std::stod(version.operator std::string()) * 10 + 0.1;
     }
-    attrs.Set("mcpu", String("sm_") + std::to_string(arch));
+    target.Set("mcpu", String("sm_") + std::to_string(arch));
   }
-  return attrs;
+  return target;
 }
 
 /*!
  * \brief Update the attributes in the LLVM ROCm target.
- * \param attrs The original attributes
+ * \param target The Target to update
  * \return The updated attributes
  */
-Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
-  CheckOrSetAttr(&attrs, "mtriple", "amdgcn-amd-amdhsa-hcc");
+TargetJSON UpdateROCmAttrs(TargetJSON target) {
+  CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
   // Update -mcpu=gfx
   int arch;
-  if (attrs.count("mcpu")) {
-    String mcpu = Downcast<String>(attrs.at("mcpu"));
+  if (target.count("mcpu")) {
+    String mcpu = Downcast<String>(target.at("mcpu"));
     arch = ExtractIntWithPrefix(mcpu, "gfx");
     ICHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
   } else {
@@ -219,7 +219,7 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
     } else {
       arch = val.operator int();
     }
-    attrs.Set("mcpu", String("gfx") + std::to_string(arch));
+    target.Set("mcpu", String("gfx") + std::to_string(arch));
   }
   // Update -mattr before ROCm 3.5:
   //   Before ROCm 3.5 we needed code object v2, starting
@@ -235,13 +235,13 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
   }
   if (version < 305) {
     Array<String> mattr;
-    if (attrs.count("mattr")) {
-      mattr = Downcast<Array<String>>(attrs.at("mattr"));
+    if (target.count("mattr")) {
+      mattr = Downcast<Array<String>>(target.at("mattr"));
     }
     mattr.push_back("-code-object-v3");
-    attrs.Set("mattr", mattr);
+    target.Set("mattr", mattr);
   }
-  return attrs;
+  return target;
 }
 
 /**********  Register Target kinds and attributes  **********/
@@ -295,7 +295,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
     .add_attr_option<Integer>("registers_per_block")
     .add_attr_option<Integer>("max_num_threads", Integer(1024))  // TODO(@zxybazh): deprecate it
     .set_default_keys({"cuda", "gpu"})
-    .set_attrs_preprocessor(UpdateCUDAAttrs);
+    .set_target_parser(UpdateCUDAAttrs);
 
 TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
     .add_attr_option<String>("mcpu")
@@ -304,7 +304,7 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
     .add_attr_option<Integer>("max_num_threads", Integer(1024))
     .add_attr_option<Integer>("thread_warp_size", Integer(32))
     .set_default_keys({"cuda", "gpu"})
-    .set_attrs_preprocessor(UpdateNVPTXAttrs);
+    .set_target_parser(UpdateNVPTXAttrs);
 
 TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
     .add_attr_option<String>("mcpu")
@@ -318,7 +318,7 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
     .add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
     .add_attr_option<Integer>("thread_warp_size", Integer(64))
     .set_default_keys({"rocm", "gpu"})
-    .set_attrs_preprocessor(UpdateROCmAttrs);
+    .set_target_parser(UpdateROCmAttrs);
 
 TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
     .add_attr_option<Bool>("system-lib")
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 2c85e47e7f..6854fc661d 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -34,6 +34,36 @@ TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU)
     .add_attr_option<Array<String>>("your_names")
     .add_attr_option<Map<String, Integer>>("her_maps");
 
+TargetJSON TestTargetParser(TargetJSON target) {
+  String mcpu = Downcast<String>(target.at("mcpu"));
+  target.Set("mcpu", String("super_") + mcpu);
+  target.Set("keys", Array<String>({"super"}));
+  return target;
+}
+
+Map<String, ObjectRef> TestAttrsPreProcessor(Map<String, ObjectRef> attrs) {
+  attrs.Set("mattr", String("woof"));
+  return attrs;
+}
+
+TVM_REGISTER_TARGET_KIND("TestTargetParser", kDLCPU)
+    .add_attr_option<String>("mattr")
+    .add_attr_option<String>("mcpu")
+    .set_default_keys({"cpu"})
+    .set_target_parser(TestTargetParser);
+
+TVM_REGISTER_TARGET_KIND("TestAttrsPreprocessor", kDLCPU)
+    .add_attr_option<String>("mattr")
+    .set_default_keys({"cpu"})
+    .set_attrs_preprocessor(TestAttrsPreProcessor);
+
+TVM_REGISTER_TARGET_KIND("TestClashingPreprocessor", kDLCPU)
+    .add_attr_option<String>("mattr")
+    .add_attr_option<String>("mcpu")
+    .set_default_keys({"cpu"})
+    .set_attrs_preprocessor(TestAttrsPreProcessor)
+    .set_target_parser(TestTargetParser);
+
 TEST(TargetKind, GetAttrMap) {
   auto map = tvm::TargetKind::GetAttrMap<std::string>("Attr1");
   auto target_kind = tvm::TargetKind::Get("TestTargetKind").value();
@@ -136,6 +166,23 @@ TEST(TargetCreationFail, TargetKindNotFound) {
   ASSERT_EQ(failed, true);
 }
 
+TEST(TargetCreation, TargetParser) {
+  Target test_target("TestTargetParser -mcpu=woof");
+  ASSERT_EQ(test_target->GetAttr<String>("mcpu").value(), "super_woof");
+  ASSERT_EQ(test_target->keys.size(), 2);
+  ASSERT_EQ(test_target->keys[0], "super");
+  ASSERT_EQ(test_target->keys[1], "cpu");
+}
+
+TEST(TargetCreation, TargetAttrsPreProcessor) {
+  Target test_target("TestAttrsPreprocessor -mattr=cake");
+  ASSERT_EQ(test_target->GetAttr<String>("mattr").value(), "woof");
+}
+
+TEST(TargetCreation, ClashingTargetProcessing) {
+  EXPECT_THROW(Target("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError);
+}
+
 TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA)
     .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));