You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by kp...@apache.org on 2022/07/18 19:37:35 UTC
[tvm] branch main updated: [Target] Add target_parser to TargetKind (#12119)
This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d4309cf810 [Target] Add target_parser to TargetKind (#12119)
d4309cf810 is described below
commit d4309cf81003c5fdeeb75583ac96cc0926a22b25
Author: Christopher Sidebottom <ch...@arm.com>
AuthorDate: Mon Jul 18 20:37:20 2022 +0100
[Target] Add target_parser to TargetKind (#12119)
This adds the `target_parser` described in https://github.com/apache/tvm-rfcs/pull/71, which parses an incoming `TargetJSON` and produces a new configuration from which the final `Target` object is generated.
Marks `set_attrs_preprocessor` as deprecated, and raises an error during `Target` creation if both `set_attrs_preprocessor` and `set_target_parser` have been set on the same `TargetKind`.
---
docs/arch/device_target_interactions.rst | 4 +--
include/tvm/target/target_kind.h | 23 ++++++++++++++
src/target/target.cc | 17 +++++++++--
src/target/target_kind.cc | 52 ++++++++++++++++----------------
tests/cpp/target_test.cc | 47 +++++++++++++++++++++++++++++
5 files changed, 112 insertions(+), 31 deletions(-)
diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst
index 9c391d31be..ec8d52226e 100644
--- a/docs/arch/device_target_interactions.rst
+++ b/docs/arch/device_target_interactions.rst
@@ -194,8 +194,8 @@ different code generation targets can run on the same physical device.
device type.)
All options for a specific target kind are added with the
-``add_attr_option`` function, with optional default values. A
-preprocessor can be added with ``set_attrs_preprocessor`` to define
+``add_attr_option`` function, with optional default values. A `Target`
+parser can be added with ``set_target_parser`` to process
any parameters that are dynamically based on other parameters or
queried from device properties.
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 4879470e76..e20f8547af 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -37,6 +37,16 @@ namespace tvm {
class Target;
+/*!
+ * \brief TargetParser to apply on instantiation of a given TargetKind
+ *
+ * \param target_json Target in JSON format to be transformed during parsing.
+ *
+ * \return The transformed Target JSON object.
+ */
+using TargetJSON = Map<String, ObjectRef>;
+using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>;
+
/*!
* \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
*
@@ -82,6 +92,8 @@ class TargetKindNode : public Object {
Array<String> default_keys;
/*! \brief Function used to preprocess on target creation */
PackedFunc preprocessor;
+ /*! \brief Function used to parse a JSON target during creation */
+ FTVMTargetParser target_parser;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
@@ -207,6 +219,11 @@ class TargetKindRegEntry {
*/
template <typename FLambda>
inline TargetKindRegEntry& set_attrs_preprocessor(FLambda f);
+ /*!
+ * \brief Set the parsing function applied upon target creation
+ * \param parser The Target parsing function
+ */
+ inline TargetKindRegEntry& set_target_parser(FTVMTargetParser parser);
/*!
* \brief Register a valid configuration option and its ValueType for validation
* \param key The configuration key
@@ -353,11 +370,17 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector<Stri
template <typename FLambda>
inline TargetKindRegEntry& TargetKindRegEntry::set_attrs_preprocessor(FLambda f) {
+ LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead";
using FType = typename tvm::runtime::detail::function_signature<FLambda>::FType;
kind_->preprocessor = tvm::runtime::TypedPackedFunc<FType>(std::move(f)).packed();
return *this;
}
+inline TargetKindRegEntry& TargetKindRegEntry::set_target_parser(FTVMTargetParser parser) {
+ kind_->target_parser = parser;
+ return *this;
+}
+
template <typename ValueType>
inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) {
ICHECK(!kind_->key2vtype_.count(key))
diff --git a/src/target/target.cc b/src/target/target.cc
index 01f9bfaeec..207a399a77 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -52,7 +52,7 @@ class TargetInternal {
static ObjectPtr<Object> FromString(const String& tag_or_config_or_target_str);
static ObjectPtr<Object> FromConfigString(const String& config_str);
static ObjectPtr<Object> FromRawString(const String& target_str);
- static ObjectPtr<Object> FromConfig(std::unordered_map<String, ObjectRef> config);
+ static ObjectPtr<Object> FromConfig(Map<String, ObjectRef> config);
static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv);
static Target WithHost(const Target& target, const Target& target_host) {
ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get());
@@ -716,17 +716,27 @@ ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
return TargetInternal::FromConfig(config);
}
-ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRef> config) {
+ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) {
const String kKind = "kind";
const String kTag = "tag";
const String kKeys = "keys";
const String kDeviceName = "device";
const String kHost = "host";
ObjectPtr<TargetNode> target = make_object<TargetNode>();
+
// parse 'kind'
if (config.count(kKind)) {
if (const auto* kind = config[kKind].as<StringObj>()) {
target->kind = GetTargetKind(GetRef<String>(kind));
+ ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr))
+ << "Cannot use both set_attrs_preprocessor and set_target_parser";
+
+ // Run JSON Parser over JSON input
+ if (target->kind->target_parser != nullptr) {
+ VLOG(9) << "TargetInternal::FromConfig - Running target_parser";
+ config = target->kind->target_parser(config);
+ }
+
config.erase(kKind);
} else {
throw Error(": Expect type of field \"kind\" is String, but get type: " +
@@ -828,8 +838,9 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
} else {
target->attrs = attrs;
}
+
return target;
-}
+} // namespace tvm
std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
const TargetNode* target) {
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 1148013706..7620c6fc2e 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -145,15 +145,15 @@ void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const Str
/*!
* \brief Update the attributes in the CUDA target.
- * \param attrs The original attributes
+ * \param target The Target to update
* \return The updated attributes
*/
-Map<String, ObjectRef> UpdateCUDAAttrs(Map<String, ObjectRef> attrs) {
+TargetJSON UpdateCUDAAttrs(TargetJSON target) {
// Update -arch=sm_xx
int archInt;
- if (attrs.count("arch")) {
+ if (target.count("arch")) {
// If -arch has been specified, validate the correctness
- String archStr = Downcast<String>(attrs.at("arch"));
+ String archStr = Downcast<String>(target.at("arch"));
archInt = ExtractIntWithPrefix(archStr, "sm_");
ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
} else {
@@ -165,23 +165,23 @@ Map<String, ObjectRef> UpdateCUDAAttrs(Map<String, ObjectRef> attrs) {
} else {
archInt = std::stod(version.operator std::string()) * 10 + 0.1;
}
- attrs.Set("arch", String("sm_") + std::to_string(archInt));
+ target.Set("arch", String("sm_") + std::to_string(archInt));
}
- return attrs;
+ return target;
}
/*!
* \brief Update the attributes in the LLVM NVPTX target.
- * \param attrs The original attributes
+ * \param target The Target to update
* \return The updated attributes
*/
-Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
- CheckOrSetAttr(&attrs, "mtriple", "nvptx64-nvidia-cuda");
+TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
+ CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
// Update -mcpu=sm_xx
int arch;
- if (attrs.count("mcpu")) {
+ if (target.count("mcpu")) {
// If -mcpu has been specified, validate the correctness
- String mcpu = Downcast<String>(attrs.at("mcpu"));
+ String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractIntWithPrefix(mcpu, "sm_");
ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
} else {
@@ -193,22 +193,22 @@ Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
} else {
arch = std::stod(version.operator std::string()) * 10 + 0.1;
}
- attrs.Set("mcpu", String("sm_") + std::to_string(arch));
+ target.Set("mcpu", String("sm_") + std::to_string(arch));
}
- return attrs;
+ return target;
}
/*!
* \brief Update the attributes in the LLVM ROCm target.
- * \param attrs The original attributes
+ * \param target The Target to update
* \return The updated attributes
*/
-Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
- CheckOrSetAttr(&attrs, "mtriple", "amdgcn-amd-amdhsa-hcc");
+TargetJSON UpdateROCmAttrs(TargetJSON target) {
+ CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
// Update -mcpu=gfx
int arch;
- if (attrs.count("mcpu")) {
- String mcpu = Downcast<String>(attrs.at("mcpu"));
+ if (target.count("mcpu")) {
+ String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractIntWithPrefix(mcpu, "gfx");
ICHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
} else {
@@ -219,7 +219,7 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
} else {
arch = val.operator int();
}
- attrs.Set("mcpu", String("gfx") + std::to_string(arch));
+ target.Set("mcpu", String("gfx") + std::to_string(arch));
}
// Update -mattr before ROCm 3.5:
// Before ROCm 3.5 we needed code object v2, starting
@@ -235,13 +235,13 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
}
if (version < 305) {
Array<String> mattr;
- if (attrs.count("mattr")) {
- mattr = Downcast<Array<String>>(attrs.at("mattr"));
+ if (target.count("mattr")) {
+ mattr = Downcast<Array<String>>(target.at("mattr"));
}
mattr.push_back("-code-object-v3");
- attrs.Set("mattr", mattr);
+ target.Set("mattr", mattr);
}
- return attrs;
+ return target;
}
/********** Register Target kinds and attributes **********/
@@ -295,7 +295,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<Integer>("registers_per_block")
.add_attr_option<Integer>("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it
.set_default_keys({"cuda", "gpu"})
- .set_attrs_preprocessor(UpdateCUDAAttrs);
+ .set_target_parser(UpdateCUDAAttrs);
TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
.add_attr_option<String>("mcpu")
@@ -304,7 +304,7 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
.add_attr_option<Integer>("max_num_threads", Integer(1024))
.add_attr_option<Integer>("thread_warp_size", Integer(32))
.set_default_keys({"cuda", "gpu"})
- .set_attrs_preprocessor(UpdateNVPTXAttrs);
+ .set_target_parser(UpdateNVPTXAttrs);
TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mcpu")
@@ -318,7 +318,7 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
.add_attr_option<Integer>("thread_warp_size", Integer(64))
.set_default_keys({"rocm", "gpu"})
- .set_attrs_preprocessor(UpdateROCmAttrs);
+ .set_target_parser(UpdateROCmAttrs);
TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
.add_attr_option<Bool>("system-lib")
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 2c85e47e7f..6854fc661d 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -34,6 +34,36 @@ TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU)
.add_attr_option<Array<String>>("your_names")
.add_attr_option<Map<String, Integer>>("her_maps");
+TargetJSON TestTargetParser(TargetJSON target) {
+ String mcpu = Downcast<String>(target.at("mcpu"));
+ target.Set("mcpu", String("super_") + mcpu);
+ target.Set("keys", Array<String>({"super"}));
+ return target;
+}
+
+Map<String, ObjectRef> TestAttrsPreProcessor(Map<String, ObjectRef> attrs) {
+ attrs.Set("mattr", String("woof"));
+ return attrs;
+}
+
+TVM_REGISTER_TARGET_KIND("TestTargetParser", kDLCPU)
+ .add_attr_option<String>("mattr")
+ .add_attr_option<String>("mcpu")
+ .set_default_keys({"cpu"})
+ .set_target_parser(TestTargetParser);
+
+TVM_REGISTER_TARGET_KIND("TestAttrsPreprocessor", kDLCPU)
+ .add_attr_option<String>("mattr")
+ .set_default_keys({"cpu"})
+ .set_attrs_preprocessor(TestAttrsPreProcessor);
+
+TVM_REGISTER_TARGET_KIND("TestClashingPreprocessor", kDLCPU)
+ .add_attr_option<String>("mattr")
+ .add_attr_option<String>("mcpu")
+ .set_default_keys({"cpu"})
+ .set_attrs_preprocessor(TestAttrsPreProcessor)
+ .set_target_parser(TestTargetParser);
+
TEST(TargetKind, GetAttrMap) {
auto map = tvm::TargetKind::GetAttrMap<std::string>("Attr1");
auto target_kind = tvm::TargetKind::Get("TestTargetKind").value();
@@ -136,6 +166,23 @@ TEST(TargetCreationFail, TargetKindNotFound) {
ASSERT_EQ(failed, true);
}
+TEST(TargetCreation, TargetParser) {
+ Target test_target("TestTargetParser -mcpu=woof");
+ ASSERT_EQ(test_target->GetAttr<String>("mcpu").value(), "super_woof");
+ ASSERT_EQ(test_target->keys.size(), 2);
+ ASSERT_EQ(test_target->keys[0], "super");
+ ASSERT_EQ(test_target->keys[1], "cpu");
+}
+
+TEST(TargetCreation, TargetAttrsPreProcessor) {
+ Target test_target("TestAttrsPreprocessor -mattr=cake");
+ ASSERT_EQ(test_target->GetAttr<String>("mattr").value(), "woof");
+}
+
+TEST(TargetCreation, ClashingTargetProcessing) {
+ EXPECT_THROW(Target("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError);
+}
+
TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));