You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/16 04:17:38 UTC
[tvm] branch main updated: [MetaSchedule] Allow MultiLevelTilingTensorCore rule to specify multiple tensor intrin groups (#12113)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 895f79f42a [MetaSchedule] Allow MultiLevelTilingTensorCore rule to specify multiple tensor intrin groups (#12113)
895f79f42a is described below
commit 895f79f42a1c798f5b70b0689ae26ae159031302
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Jul 15 21:17:31 2022 -0700
[MetaSchedule] Allow MultiLevelTilingTensorCore rule to specify multiple tensor intrin groups (#12113)
---
include/tvm/meta_schedule/schedule_rule.h | 21 +--
.../schedule_rule/multi_level_tiling.py | 16 +-
python/tvm/meta_schedule/testing/schedule_rule.py | 26 +++-
.../multi_level_tiling_tensor_core.cc | 145 ++++++++++++-------
...ta_schedule_schedule_rule_multi_level_tiling.py | 161 +++++++++++++++++----
5 files changed, 260 insertions(+), 109 deletions(-)
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 5e4698db17..b5f4a17b69 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -174,13 +174,13 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
/*!
- * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
- * intrinsics
- * \param intrin_group A group of tensor core intrinsics. The map should contains key "init",
- * "load_a", "load_b", "compute", "store", which represent the tensor intrin for initialization,
- * loading operand A, loading operand B, tensor core computation, storing the result. The value of
- * the map should be names of tensor intrinsics, must be registerd via TensorIntrin.register(...)
- * beforehand
+ * \brief Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate
+ * tensor core intrinsics
+ * \param intrin_groups A list of groups of tensor core intrinsics. Each map should contain the
+ * keys "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for
+ * initialization, loading operand A, loading operand B, tensor core computation, storing the
+ * result. The value of the map should be names of tensor intrinsics, which must be registered
+ * via TensorIntrin.register(...) beforehand
* \param structure The tiling structure. Recommended:
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
@@ -193,9 +193,10 @@ class ScheduleRule : public runtime::ObjectRef {
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
- Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
- Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
- Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+ Array<Map<String, String>> intrin_groups, String structure,
+ Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
+ Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
+ Optional<Map<String, ObjectRef>> reuse_write);
/*!
* \brief Create a rule: add-rfactor to some blocks if needed
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 71fbaee4f6..a728a91eb7 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -135,15 +135,15 @@ class MultiLevelTilingWithIntrin(ScheduleRule):
@register_object("meta_schedule.MultiLevelTilingTensorCore")
class MultiLevelTilingTensorCore(ScheduleRule):
- """Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
- intrinsics.
+ """Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate tensor
+ core intrinsics.
Parameters
----------
- intrin_group : Mapping[str, str]
- A group of tensor core intrinsics. The map should contains key "init", "load_a", "load_b",
- "compute", "store", which represent the tensor intrin for initialization, loading operand A,
- loading operand B, tensor core computation, storing the result.
+ intrin_groups : List[Mapping[str, str]]
+ A list of groups of tensor core intrinsics. Each map should contain the keys "init", "load_a",
+ "load_b", "compute", "store", which represent the tensor intrin for initialization,
+ loading operand A, loading operand B, tensor core computation, storing the result.
The value of the map should be names of tensor intrinsics, which must be registered via
TensorIntrin.register(...) beforehand
structure : str
@@ -165,7 +165,7 @@ class MultiLevelTilingTensorCore(ScheduleRule):
def __init__(
self,
- intrin_group: Mapping[str, str],
+ intrin_groups: List[Mapping[str, str]],
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
@@ -175,7 +175,7 @@ class MultiLevelTilingTensorCore(ScheduleRule):
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member
- intrin_group,
+ intrin_groups,
structure,
tile_binds,
max_innermost_factor,
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index 717be59512..ea748ddc05 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Default schedule rules"""
+from typing import List, Union
from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoBind,
@@ -114,18 +115,29 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
def multi_level_tiling_tensor_core(
target: Target,
- write_reuse_scope="shared",
- in_dtype="float16",
- out_dtype="float32",
- trans_b=False,
+ write_reuse_scope: str = "shared",
+ in_dtype: Union[str, List[str]] = "float16",
+ out_dtype: Union[str, List[str]] = "float32",
+ trans_b: Union[bool, List[bool]] = False,
) -> ScheduleRule:
"""Default schedule rules for with multi-level tiling reuse for tensor core"""
assert write_reuse_scope in ["shared", "global"]
+ if not isinstance(in_dtype, list):
+ in_dtype = [in_dtype]
+ if not isinstance(out_dtype, list):
+ out_dtype = [out_dtype]
+ if not isinstance(trans_b, list):
+ trans_b = [trans_b]
+
if target.kind.name == "cuda":
+ intrin_groups = [
+ tensor_intrin.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
+ for _in_dtype in in_dtype
+ for _out_dtype in out_dtype
+ for _trans_b in trans_b
+ ]
return MultiLevelTilingTensorCore(
- intrin_group=tensor_intrin.get_wmma_intrin_group(
- write_reuse_scope, in_dtype, out_dtype, trans_b
- ),
+ intrin_groups=intrin_groups,
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4, # 64 // tensor intrin size
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 91df62fc36..6d34f7b64e 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -38,10 +38,42 @@ struct TensorCoreIntrinGroup {
String load_b_intrin;
String compute_intrin;
String store_intrin;
+
+ /*! \brief Create TensorCoreIntrinGroup from config in a map. The map should contain the
+ * following keys:
+ * - init
+ * - load_a
+ * - load_b
+ * - compute
+ * - store
+ * The values of the keys should be the names of the corresponding intrinsics and should be
+ * registered via TensorIntrin.Register beforehand.
+ */
+ static TensorCoreIntrinGroup FromConfig(const Map<String, String>& config);
};
+TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map<String, String>& config) {
+ auto f_initialize_intrin = [&config](String key_name, String* intrin_name) {
+ CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set.";
+ *intrin_name = config.at(key_name);
+ // Check the existence of the intrin
+ tir::TensorIntrin::Get(*intrin_name);
+ };
+ TensorCoreIntrinGroup intrin_group;
+ f_initialize_intrin("init", &intrin_group.init_intrin);
+ f_initialize_intrin("load_a", &intrin_group.load_a_intrin);
+ f_initialize_intrin("load_b", &intrin_group.load_b_intrin);
+ f_initialize_intrin("compute", &intrin_group.compute_intrin);
+ f_initialize_intrin("store", &intrin_group.store_intrin);
+ return intrin_group;
+}
+
class TensorCoreStateNode : public StateNode {
public:
+ /*! \brief The tensor core intrinsic group. */
+ TensorCoreIntrinGroup intrin_group;
+ /*! \brief The auto tensorization mapping info. */
+ tir::AutoTensorizeMappingInfo mapping_info{nullptr};
/*! \brief The Tensor Core reindex block A for Tensor Core computation */
tir::BlockRV tensor_core_reindex_A;
/*! \brief The Tensor Core reindex block B for Tensor Core computation */
@@ -57,16 +89,21 @@ class TensorCoreStateNode : public StateNode {
class TensorCoreState : public State {
public:
- explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
- Array<Array<tir::LoopRV>> tiles = {});
+ explicit TensorCoreState(TensorCoreIntrinGroup intrin_group,
+ tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
+ BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
};
TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);
-TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array<Array<LoopRV>> tiles) {
+TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group,
+ tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
+ BlockRV block_rv, Array<Array<LoopRV>> tiles) {
ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>();
+ node->intrin_group = intrin_group;
+ node->mapping_info = mapping_info;
node->sch = std::move(sch);
node->block_rv = std::move(block_rv);
node->tiles = std::move(tiles);
@@ -116,16 +153,12 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
const String& intrin_name) const;
public:
- /*! \brief The tensor core intrin group to apply */
- TensorCoreIntrinGroup intrin_group;
+ /*! \brief The candidate tensor core intrin groups to apply */
+ std::vector<TensorCoreIntrinGroup> intrin_groups;
static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);
private:
- /*!
- * \brief The mapping info for auto tensorization
- */
- tir::AutoTensorizeMappingInfo mapping_info_{nullptr};
};
// Entry of the mega rule; Inherited from ScheduleRuleNode
@@ -135,21 +168,36 @@ Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
return {sch};
}
- Optional<tir::AutoTensorizeMappingInfo> mapping_info =
- tir::GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block_rv),
- tir::TensorIntrin::Get(intrin_group.compute_intrin)->desc);
- if (!mapping_info.defined()) {
+ std::unordered_map<int, tir::AutoTensorizeMappingInfo> intrin_group_to_mapping_info;
+ for (int i = 0, n = intrin_groups.size(); i < n; ++i) {
+ TensorCoreIntrinGroup intrin_group = intrin_groups[i];
+ Optional<tir::AutoTensorizeMappingInfo> mapping_info = tir::GetAutoTensorizeMappingInfo(
+ sch->state(), sch->GetSRef(block_rv),
+ tir::TensorIntrin::Get(intrin_groups[i].compute_intrin)->desc);
+ if (mapping_info.defined()) {
+ intrin_group_to_mapping_info.emplace(i, mapping_info.value());
+ }
+ }
+
+ if (intrin_group_to_mapping_info.empty()) {
+ // No tensor intrinsics can be applied.
return {sch};
}
- mapping_info_ = mapping_info.value();
- // Create a copy of the schedule so that we can roll back transformations if tensorization
+ // Save the original schedule so that we can roll back transformations if tensorization
// fails.
- Schedule original_sch = sch->Copy();
- sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
-
+ Schedule original_sch = sch;
+
+ std::vector<State> initial_states;
+ for (const auto& kv : intrin_group_to_mapping_info) {
+ const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first];
+ const tir::AutoTensorizeMappingInfo& mapping_info = kv.second;
+ Schedule new_sch = sch->Copy();
+ new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
+ initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv));
+ }
Array<Schedule> results;
- for (auto&& state : ApplySubRules({TensorCoreState(sch, block_rv)})) {
+ for (auto&& state : ApplySubRules(initial_states)) {
results.push_back(std::move(state->sch));
}
if (results.empty()) {
@@ -196,7 +244,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
}
sch->ReverseComputeInline(state->tensor_core_reindex_store);
- TileAndAnnotateTensorize(&sch, cache_write, intrin_group.store_intrin);
+ TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
return {state};
}
@@ -212,8 +260,8 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
TileAndAnnotateTensorize(&sch, cache_read, intrin_name);
};
- f_tensorize_load(0, "wmma.matrix_a", intrin_group.load_a_intrin);
- f_tensorize_load(1, "wmma.matrix_b", intrin_group.load_b_intrin);
+ f_tensorize_load(0, "wmma.matrix_a", state->intrin_group.load_a_intrin);
+ f_tensorize_load(1, "wmma.matrix_b", state->intrin_group.load_b_intrin);
sch->ComputeInline(state->tensor_core_reindex_A);
sch->ComputeInline(state->tensor_core_reindex_B);
@@ -238,6 +286,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
TensorCoreStateNode* state, const String& intrin_name) const {
BlockRV block_rv = state->block_rv;
+ const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info;
tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);
// Add reindex stages
@@ -258,24 +307,24 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
// Transform the layout of reindex buffers accordingly.
// The index map defines the mapping for the computation block. We need to extract the sub index
// map to transform the load and store block.
- ICHECK_EQ(mapping_info_->mappings.size(), 1U); // assume only one mapping is present
- const tir::IndexMap& index_map = mapping_info_->mappings[0];
+ ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present
+ const tir::IndexMap& index_map = mapping_info->mappings[0];
// Find the correspondence between block iters and the iters in the index map.
std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> lhs_to_index_map_src;
std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> rhs_to_index_map_tgt;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_index_map_src;
- ICHECK_EQ(mapping_info_->lhs_iters.size(), index_map->initial_indices.size());
- for (int i = 0; i < static_cast<int>(mapping_info_->lhs_iters.size()); ++i) {
- lhs_to_index_map_src[mapping_info_->lhs_iters[i]->var] = index_map->initial_indices[i];
+ ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size());
+ for (int i = 0; i < static_cast<int>(mapping_info->lhs_iters.size()); ++i) {
+ lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i];
}
// The number of result iters in the index map is equal or more than the number of rhs (the
- // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from the
- // lhs. They will be skipped during pattern matching for tensorization.
- // An example of such case is batch matmul, the batch dimension is kept after layout
- // transformations and it will be kept as a outer loop after tensorization.
+ // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from
+ // the lhs. They will be skipped during pattern matching for tensorization. An example of such
+ // case is batch matmul, the batch dimension is kept after layout transformations and it will be
+ // kept as an outer loop after tensorization.
int offset = static_cast<int>(index_map->final_indices.size()) -
- static_cast<int>(mapping_info_->rhs_iters.size());
+ static_cast<int>(mapping_info->rhs_iters.size());
ICHECK_GE(offset, 0);
for (int i = 0; i < offset; ++i) {
const tir::VarNode* var_ptr = index_map->final_indices[i].as<tir::VarNode>();
@@ -283,13 +332,13 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
unmapped_index_map_src.insert(GetRef<tir::Var>(var_ptr));
}
for (int i = offset; i < static_cast<int>(index_map->final_indices.size()); ++i) {
- rhs_to_index_map_tgt[mapping_info_->rhs_iters[i - offset]->var] = index_map->final_indices[i];
+ rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i];
}
auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) {
std::vector<tir::Var> sub_index_map_src;
std::vector<PrimExpr> sub_index_map_tgt;
- const tir::Buffer& rhs_buffer = mapping_info_->lhs_buffer_map[lhs_buffer];
+ const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer];
for (const Range& range : lhs_region) {
ICHECK(tir::is_one(range->extent));
const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
@@ -300,8 +349,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
sub_index_map_tgt.push_back(lhs_representer);
}
}
- for (size_t i = 0; i < mapping_info_->rhs_buffer_indices[rhs_buffer].size(); ++i) {
- const tir::VarNode* var = mapping_info_->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
+ for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) {
+ const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
ICHECK(var != nullptr);
sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef<tir::Var>(var)]);
}
@@ -345,7 +394,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
TensorCoreState state) const {
// Do reindex and layout transformations.
Optional<LoopRV> transformed_loop_rv =
- TransformWithTensorIntrin(state.operator->(), intrin_group.compute_intrin);
+ TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin);
if (!transformed_loop_rv.defined()) {
// The workload can't be tensorized.
return {};
@@ -356,32 +405,24 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
// Add annotations for post processors.
state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize,
- intrin_group.compute_intrin);
+ state->intrin_group.compute_intrin);
state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
- intrin_group.init_intrin);
+ state->intrin_group.init_intrin);
state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true));
return {std::move(state)};
}
ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
- Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
+ Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
- auto f_initialize_intrin = [&intrin_group](String key_name, String* intrin_name) {
- CHECK(intrin_group.count(key_name)) << "ValueError: " << key_name << " is not set.";
- *intrin_name = intrin_group.at(key_name);
- // Check the existence of the intrin
- tir::TensorIntrin::Get(*intrin_name);
- };
- f_initialize_intrin("init", &node->intrin_group.init_intrin);
- f_initialize_intrin("load_a", &node->intrin_group.load_a_intrin);
- f_initialize_intrin("load_b", &node->intrin_group.load_b_intrin);
- f_initialize_intrin("compute", &node->intrin_group.compute_intrin);
- f_initialize_intrin("store", &node->intrin_group.store_intrin);
-
+ node->intrin_groups.reserve(intrin_groups.size());
+ for (const auto& intrin_group_config : intrin_groups) {
+ node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
+ }
return ScheduleRule(node);
}
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 1ceef0afc3..c43645832b 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -563,25 +563,6 @@ sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_v
check_trace(spaces, expected)
-def test_cuda_tensor_core_conv2d():
- target = Target("cuda", host="llvm")
- ctx = _create_context(
- create_prim_func(
- te_workload.conv2d_nhwc_f16(
- N=1, H=16, W=16, CI=16, CO=16, kernel_size=3, stride=1, padding=1
- )
- ),
- target,
- multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
- )
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 1
-
- expected = []
- print("".join(spaces[0].trace.as_python()))
- check_trace(spaces, expected)
-
-
def test_cuda_tensor_core_matmul_relu():
m = n = k = 128
target = Target("cuda", host="llvm")
@@ -719,14 +700,15 @@ sch.reverse_compute_inline(block=b1)""".split(
def test_cuda_tensor_core_matmul_relu_global():
m = n = k = 128
target = Target("cuda", host="llvm")
- ctx = _create_context(
- create_prim_func(
- te_workload.matmul_relu_fp16(
- n=n,
- m=m,
- k=k,
- ),
+ workload = create_prim_func(
+ te_workload.matmul_relu_fp16(
+ n=n,
+ m=m,
+ k=k,
),
+ )
+ ctx = _create_context(
+ workload,
target=target,
rule=[
multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"),
@@ -822,6 +804,106 @@ sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".sp
]
check_trace(spaces, expected)
+ ctx = _create_context(
+ workload,
+ target=target,
+ rule=[
+ multi_level_tiling_tensor_core(
+ target=target, write_reuse_scope="global", trans_b=[False, True]
+ ),
+ auto_inline(target),
+ ],
+ )
+ spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+ assert len(spaces) == 2
+
+ expected = [
+ expected[0],
+ """b0 = sch.get_block(name="C", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b1 = sch.reindex(block=b0, buffer=("write", 0))
+b2 = sch.reindex(block=b0, buffer=("read", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (j, k, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
+l4, l5, l6 = sch.get_loops(block=b0)
+l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True)
+l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0)
+sch.reorder(l15, l17, l12, l10, l8)
+b19 = sch.blockize(loop=l12)
+sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32_trans")
+sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1)
+l20, l21, l22 = sch.get_loops(block=b19)
+v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4)
+l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True)
+v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
+l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True)
+v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4)
+l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True)
+sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)
+l49 = sch.fuse(l28, l38, preserve_unit_iters=True)
+sch.bind(loop=l49, thread_axis="blockIdx.y")
+l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
+sch.bind(loop=l50, thread_axis="blockIdx.x")
+l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
+sch.bind(loop=l51, thread_axis="threadIdx.y")
+b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True)
+sch.reverse_compute_inline(block=b1)
+l53, l54, l55, l56, l57 = sch.get_loops(block=b52)
+l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True)
+l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True)
+l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52)
+sch.reorder(l67, l61, l59)
+b69 = sch.blockize(loop=l61)
+sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global")
+b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
+l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
+l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
+v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
+b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
+l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
+l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
+v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
+b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
+l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88)
+l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True)
+l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True)
+l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88)
+sch.reorder(l107, l99, l97)
+b109 = sch.blockize(loop=l99)
+sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True)
+l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110)
+l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True)
+l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True)
+l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110)
+sch.reorder(l129, l121, l119)
+b131 = sch.blockize(loop=l121)
+sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b_trans")
+sch.compute_inline(block=b2)
+sch.compute_inline(block=b3)
+sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
+ "\n"
+ ),
+ ]
+ check_trace(spaces, expected)
+
def test_multi_level_tiling_non_tensorizable():
# expected to do nothing on non-tensorizable workloads
@@ -850,13 +932,13 @@ def test_multi_level_tiling_non_tensorizable():
def test_cuda_tensor_core_conv2d():
target = Target("cuda", host="llvm")
+ workload = create_prim_func(
+ te_workload.conv2d_nhwc_f16(
+ N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1
+ )
+ )
ctx = _create_context(
- create_prim_func(
- # dtype doesn't match tensor intrin
- te_workload.conv2d_nhwc_f16(
- N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1
- )
- ),
+ workload,
target=target,
rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
)
@@ -955,6 +1037,21 @@ sch.storage_align(block=b94, buffer_index=0, axis=-2, factor=32, offset=8)""".sp
]
check_trace(spaces, expected)
+ # Test that adding inapplicable tensor intrinsics doesn't change the search space
+ ctx = _create_context(
+ workload,
+ target,
+ multi_level_tiling_tensor_core(
+ target=target,
+ write_reuse_scope="shared",
+ in_dtype="float16",
+ out_dtype=["float16", "float32"],
+ ),
+ )
+ check_trace(spaces, expected)
+ spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+ assert len(spaces) == 1
+
if __name__ == "__main__":
tvm.testing.main()